#include "vibeqc/periodic_fock.hpp"

#include "vibeqc/init.hpp"
#include "vibeqc/schwarz.hpp"
#include "vibeqc/thread_pool.hpp"

#include <libint2/engine.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

namespace vibeqc {

namespace {

// Local short-name alias: forwards `shells` and the displacement `dr`
// unchanged to shift_shells_to_cell and returns its result, so existing call
// sites in this file can keep using the short name.
// NOTE(review): build_jk_2e_real_space_impl declares a local lambda with the
// same name that shadows this helper inside that function.
inline std::vector<libint2::Shell> shift_shells(
    const libint2::BasisSet& shells, const Eigen::Vector3d& dr) {
    return shift_shells_to_cell(shells, dr);
}

// Hash functor for integer lattice indices used as unordered_map keys.
// Each component is biased by +1024, widened to 64 bits, and the three values
// are combined by multiply-and-XOR (the last component enters unmultiplied).
// Cells live in a small bounded box, so any reasonable mixing is sufficient.
struct LatticeIndexHash {
    std::size_t operator()(const Eigen::Vector3i& v) const noexcept {
        const std::uint64_t x = static_cast<std::uint64_t>(v[0] + 1024);
        const std::uint64_t y = static_cast<std::uint64_t>(v[1] + 1024);
        const std::uint64_t z = static_cast<std::uint64_t>(v[2] + 1024);

        std::uint64_t mixed = x * 2654435761ULL;
        mixed ^= y * 40503ULL;
        mixed ^= z;
        return static_cast<std::size_t>(mixed);
    }
};
// Equality functor paired with LatticeIndexHash: two lattice indices are
// equal exactly when all three integer components match.
struct LatticeIndexEq {
    bool operator()(const Eigen::Vector3i& a,
                    const Eigen::Vector3i& b) const noexcept {
        for (int axis = 0; axis < 3; ++axis) {
            if (a[axis] != b[axis]) return false;
        }
        return true;
    }
};
// Integer lattice index → int lookup. build_cell_index_map fills it with each
// cell's position in the vector it was built from.
using CellIndexMap =
    std::unordered_map<Eigen::Vector3i, int, LatticeIndexHash, LatticeIndexEq>;
// Integer lattice index → dense matrix lookup, using the same hash/equality
// functors. Not referenced in the part of the file visible here.
using WeightLookup =
    std::unordered_map<Eigen::Vector3i, Eigen::MatrixXd, LatticeIndexHash, LatticeIndexEq>;

// Build the lattice-index → position lookup for a cell list. If the same
// lattice index occurs more than once, the last occurrence's position wins.
CellIndexMap build_cell_index_map(const std::vector<LatticeCell>& cells) {
    CellIndexMap position_of;
    position_of.reserve(2 * cells.size());

    std::size_t position = 0;
    for (const auto& cell : cells) {
        position_of[cell.index] = static_cast<int>(position);
        ++position;
    }
    return position_of;
}

}  // namespace

// Build the Γ-only "molecular limit" Coulomb (J) and exchange (K) matrices.
//
// For every pair of cells (c_g, c_p) from direct_lattice_cells(system,
// opts.cutoff_bohr) this accumulates
//   J(μ, ν) += P(λ, σ) · (μ_0 ν_g | λ_p σ_p)
//   K(μ, ν) += P(λ, σ) · (μ_0 λ_p | ν_g σ_p)
// into nbf × nbf matrices, using per-thread accumulators that are summed at
// the end.
//
// Parameters:
//   basis  — supplies the libint shells and the basis-function count.
//   system — passed to direct_lattice_cells to generate the cell list.
//   opts   — cutoff_bohr bounds the cell list; schwarz_threshold > 0 enables
//            Cauchy–Schwarz screening (≤ 0 disables it).
//   P      — density matrix; must be nbf × nbf.
//   omega  — must be ≥ 0; when > 0 the erfc-attenuated Coulomb operator is
//            used with omega passed as the engine parameter.
//
// Returns JKMatrices{J, K}.
// Throws std::runtime_error on a density shape mismatch or negative omega.
//
// NOTE(review): with screening enabled, K quartets whose ket-pair
// displacement c_p − c_g is not in the cell list are skipped, whereas the
// unscreened path evaluates them — confirm this asymmetry is intended.
JKMatrices build_jk_gamma_molecular_limit(const BasisSet& basis,
                                          const PeriodicSystem& system,
                                          const LatticeSumOptions& opts,
                                          const Eigen::MatrixXd& P,
                                          double omega) {
    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());
    if (P.rows() != nbf || P.cols() != nbf) {
        throw std::runtime_error(
            "build_jk_gamma_molecular_limit: density shape mismatch");
    }
    if (omega < 0.0) {
        throw std::runtime_error(
            "build_jk_gamma_molecular_limit: omega must be non-negative");
    }

    // Cell list shared between both lattice-index sums. Consistent with
    // Phase 12a's convention that cutoff_bohr bounds the μν real-space
    // sum; the two indices share the same cell lattice here because in
    // the Γ-only molecular limit their cutoffs are effectively coupled
    // through the ν ↔ P-image symmetry.
    const auto cells = direct_lattice_cells(system, opts.cutoff_bohr);

    // When ω > 0 we use libint's erfc-screened Coulomb kernel
    // (erfc(ω·r_12)/r_12) — the short-range Ewald-split piece. The
    // scalar ω is passed to the engine via set_params.
    const libint2::Operator op = (omega > 0.0)
        ? libint2::Operator::erfc_coulomb
        : libint2::Operator::coulomb;
    libint2::Engine prototype(op, shells_ref.max_nprim(),
                              shells_ref.max_l(), 0);
    if (omega > 0.0) {
        prototype.set_params(omega);
    }
    // One engine per thread, indexed below by omp_thread_index().
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();
    const std::size_t nshells = shells_ref.size();

    // Pre-shift shells for every cell; amortise across the doubled loop.
    // Memory: O(N_c × n_shells × shell-size) — small. Shared across threads.
    std::vector<std::vector<libint2::Shell>> shells_at(cells.size());
    for (std::size_t c = 0; c < cells.size(); ++c) {
        shells_at[c] = shift_shells(shells_ref, cells[c].r_cart);
    }

    // ---- Cauchy–Schwarz pre-pass ------------------------------------------
    //
    // The Γ-only molecular-limit J/K kernel has loop shape (c_g, c_p) with
    // ν shifted by g and {λ, σ} both shifted by p. For each shell quartet:
    //
    //   J: (μ_0 ν_g | λ_p σ_p) — bra-pair displacement c_g, ket-pair 0.
    //   K: (μ_0 λ_p | ν_g σ_p) — bra-pair displacement c_p, ket-pair c_p−c_g.
    //
    // Schwarz: |⟨ab|cd⟩| ≤ Q_ab · Q_cd. Skip the quartet when the bound
    // (× density envelope) falls below ``opts.schwarz_threshold``.
    const double schwarz_thr = opts.schwarz_threshold;
    const bool screen = (schwarz_thr > 0.0);
    const auto Q = screen
        ? compute_schwarz_factors_per_cell(shells_ref, shells_at, prototype)
        : std::vector<std::vector<double>>{};
    const auto Q_max = screen
        ? max_q_per_cell(Q)
        : std::vector<double>{};
    // Global density envelope for the cell-level test, floored at 1.0 so the
    // cell-level bound is never tighter than the bare Q product.
    const double D_max = screen
        ? std::max(1.0, std::max(std::fabs(P.maxCoeff()),
                                  std::fabs(P.minCoeff())))
        : 0.0;
    // Sx3: per-shell-pair density envelope for LinK-style screening.
    // Dpair(s, t) = max |P(a, b)| over basis functions a in shell s, b in t.
    Eigen::MatrixXd Dpair;
    if (screen) {
        Dpair = Eigen::MatrixXd::Zero(
            static_cast<Eigen::Index>(nshells),
            static_cast<Eigen::Index>(nshells));
        for (std::size_t s = 0; s < nshells; ++s) {
            const auto bf_s = shell2bf[s];
            const auto n_s = shells_ref[s].size();
            for (std::size_t t = 0; t < nshells; ++t) {
                const auto bf_t = shell2bf[t];
                const auto n_t = shells_ref[t].size();
                double mx = 0.0;
                for (std::size_t i = 0; i < n_s; ++i)
                    for (std::size_t j = 0; j < n_t; ++j)
                        mx = std::max(mx, std::fabs(P(static_cast<Eigen::Index>(bf_s + i),
                                                        static_cast<Eigen::Index>(bf_t + j))));
                Dpair(static_cast<Eigen::Index>(s), static_cast<Eigen::Index>(t)) = mx;
            }
        }
    }
    CellIndexMap cell_index_map;
    if (screen) cell_index_map = build_cell_index_map(cells);
    // The zero cell provides the Schwarz factors for the J ket pair (both
    // functions in the same cell, i.e. zero relative displacement). It is
    // looked up through the index map rather than assumed to sit at
    // position 0 of the cell list.
    int c_zero_idx = -1;
    if (screen) {
        auto it = cell_index_map.find(Eigen::Vector3i(0, 0, 0));
        if (it != cell_index_map.end()) c_zero_idx = it->second;
    }
    const double q_zero_max =
        (screen && c_zero_idx >= 0) ? Q_max[c_zero_idx] : 0.0;

    // Parallelise the double cell loop over a flat (c_g, c_p) index. Each
    // iteration accumulates into thread-local J and K buffers; we reduce
    // across threads at the end. Without the thread-local buffers there
    // would be write races, because J and K would be hot shared
    // accumulators across all the nested shell-quartet loops.
    // NOTE(review): n_c * n_c is evaluated in int and would overflow for
    // cell lists beyond ~46k entries — confirm cutoffs keep n_c well below.
    const int n_c = static_cast<int>(cells.size());
    const int n_pairs = n_c * n_c;
    const int n_threads = omp_max_threads();
    std::vector<Eigen::MatrixXd> Jm_tls(
        n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
    std::vector<Eigen::MatrixXd> Km_tls(
        n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

    #pragma omp parallel for schedule(dynamic)
    for (int idx = 0; idx < n_pairs; ++idx) {
        // Unflatten the pair index: c_g is the slow index, c_p the fast one.
        const int c_g = idx / n_c;
        const int c_p = idx % n_c;
        const auto tid = static_cast<std::size_t>(omp_thread_index());
        auto& engine = engines[tid];
        const auto& buf = engine.results();
        auto& Jm_local = Jm_tls[tid];
        auto& Km_local = Km_tls[tid];

        const auto& shells_g = shells_at[c_g];       // ν shifted by g
        const auto& shells_p = shells_at[c_p];       // λ, σ shifted by p

        // Cell index of c_p − c_g for the K ket-pair displacement.
        // Stays -1 when that displacement is not in the cell list.
        int c_pg_idx = -1;
        if (screen) {
            const Eigen::Vector3i pg = cells[c_p].index - cells[c_g].index;
            auto it = cell_index_map.find(pg);
            if (it != cell_index_map.end()) c_pg_idx = it->second;
        }

        // Cell-level Schwarz: skip the entire shell-quartet loop when
        // neither J nor K can possibly contribute above the threshold.
        if (screen) {
            const bool j_possible =
                (Q_max[c_g] * q_zero_max * D_max >= schwarz_thr);
            const double q_pg_max =
                (c_pg_idx >= 0) ? Q_max[c_pg_idx] : 0.0;
            const bool k_possible =
                (Q_max[c_p] * q_pg_max * D_max >= schwarz_thr);
            if (!j_possible && !k_possible) continue;
        }

        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1];
            const auto n1 = shells_ref[s1].size();

            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                const auto bf2 = shell2bf[s2];
                const auto n2 = shells_g[s2].size();
                // J bra: (s1_0, s2_g). K ket: (s2_g, s4_p).
                const double q12_J = screen
                    ? Q[c_g][s1 * nshells + s2]
                    : 0.0;

                for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                    const auto bf3 = shell2bf[s3];
                    const auto n3 = shells_p[s3].size();
                    // K bra: (s1_0, s3_p).
                    const double q13_K = screen
                        ? Q[c_p][s1 * nshells + s3]
                        : 0.0;

                    for (std::size_t s4 = 0; s4 < shells_p.size(); ++s4) {
                        const auto bf4 = shell2bf[s4];
                        const auto n4 = shells_p[s4].size();

                        // J: (μ_0 ν_g | λ_p σ_p) — ket pair at c=0.
                        bool do_J = true;
                        if (screen) {
                            if (c_zero_idx < 0) {
                                do_J = false;
                            } else {
                                const double q34_J =
                                    Q[c_zero_idx][s3 * nshells + s4];
                                // Sx3: shell-pair density envelope for the
                                // ket pair (λ_p σ_p) — the density element
                                // that multiplies (μ_0 ν_g | λ_p σ_p).
                                const double den34 = Dpair(
                                    static_cast<Eigen::Index>(s3),
                                    static_cast<Eigen::Index>(s4));
                                if (q12_J * q34_J * den34 < schwarz_thr)
                                    do_J = false;
                            }
                        }
                        if (do_J) {
                            engine.compute(shells_ref[s1], shells_g[s2],
                                           shells_p[s3], shells_p[s4]);
                            // buf[0] is null when the engine produced no
                            // integrals for this quartet; skip in that case.
                            if (const double* blk = buf[0]) {
                                // Integral block is indexed row-major in the
                                // shell order passed to compute: (i, j, k, l).
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double v = blk[
                                        ((i * n2 + j) * n3 + k) * n4 + l];
                                    Jm_local(bf1 + i, bf2 + j) +=
                                        P(bf3 + k, bf4 + l) * v;
                                }
                            }
                        }

                        // K: (μ_0 λ_p | ν_g σ_p) — ket pair at c_p − c_g.
                        bool do_K = true;
                        if (screen) {
                            if (c_pg_idx < 0) {
                                do_K = false;
                            } else {
                                const double q24_K =
                                    Q[c_pg_idx][s2 * nshells + s4];
                                // Sx3: shell-pair density envelope for the
                                // density pair (λ σ) — P(λ, σ) multiplies
                                // (μ_0 λ_p | ν_g σ_p) in the K build.
                                const double den34 = Dpair(
                                    static_cast<Eigen::Index>(s3),
                                    static_cast<Eigen::Index>(s4));
                                if (q13_K * q24_K * den34 * 0.5 < schwarz_thr)
                                    do_K = false;
                            }
                        }
                        if (do_K) {
                            engine.compute(shells_ref[s1], shells_p[s3],
                                           shells_g[s2], shells_p[s4]);
                            if (const double* blk = buf[0]) {
                                // Shell order passed to compute is
                                // (s1, s3, s2, s4), so the block is indexed
                                // as (i, k, j, l).
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double v = blk[
                                        ((i * n3 + k) * n2 + j) * n4 + l];
                                    Km_local(bf1 + i, bf2 + j) +=
                                        P(bf3 + k, bf4 + l) * v;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Reduce thread-local accumulators.
    Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
    Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
    for (const auto& m : Jm_tls) Jm += m;
    for (const auto& m : Km_tls) Km += m;
    return JKMatrices{Jm, Km};
}

// Shared implementation behind the build_jk_2e_real_space* entry points.
//
// For every requested output cell g this accumulates, over all (c_λ, c_σ) in
// `cells`,
//   J_g(μ, ν) += P_h(λ, σ) · (μ_0 ν_g | λ_{c_λ} σ_{c_σ})
//   K_g(μ, ν) += P_h(λ, σ) · (μ_0 λ_{c_λ} | ν_g σ_{c_σ})   (compute_exchange)
// where P_h is the density block whose lattice index equals
// cells[c_σ].index − cells[c_λ].index. Cell pairs with no such density block,
// or with an empty (0×0) block, contribute nothing and are skipped.
//
// Parameters:
//   basis            — supplies the libint shells and basis-function count.
//   system           — not read by this routine (kept in the signature).
//   opts             — only schwarz_threshold is read; > 0 enables
//                      Cauchy–Schwarz screening, ≤ 0 disables it.
//   P_real_space     — density blocks keyed by lattice index; nbf must match.
//   cells            — cell list used for the output blocks and for the
//                      internal c_λ / c_σ sums.
//   compute_exchange — when false, K blocks are left zero.
//   omega            — must be ≥ 0; > 0 selects the erfc-attenuated Coulomb
//                      operator with omega as the engine parameter.
//   output_indices / output_shell_masks — see the inline parameter comments.
//
// Returns {J_set, K_set}; each carries `cells` and one nbf × nbf block per
// cell (zero for blocks / shell pairs that were not requested).
//
// Throws std::runtime_error when: the density nbf mismatches; omega < 0; an
// output index is out of range or repeated (a repeat would make two parallel
// iterations write the same block); the mask count differs from the number of
// output cells; or a mask is shorter than nshells*nshells.
JKLatticeMatrixSets build_jk_2e_real_space_impl(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const LatticeSumOptions& opts,
        const LatticeMatrixSet& P_real_space,
        const std::vector<LatticeCell>& cells,
        bool compute_exchange,
        double omega,
        // Phase SYM3b: when non-empty, build output (bra-cell c_g) blocks
        // ONLY for these cell indices; the internal c_λ/c_σ lattice sum stays
        // FULL over `cells`, so each computed block is exact. Other blocks are
        // left zero for the caller to fill by point-group reconstruction. An
        // empty list builds every cell (the original full behaviour).
        const std::vector<int>& output_indices = {},
        // Phase SYM3b shell-pair mask: when non-empty (and parallel to
        // `output_indices`), `output_shell_masks[oi]` is an nshells*nshells
        // row-major flag array; only output shell pairs (s1_home, s2_cg) with a
        // non-zero flag are built. This restricts the build to the
        // atom-pair-orbit *representative* sub-blocks (not just rep cells), the
        // finer reduction the point-group reconstruction actually needs. The
        // internal c_λ/c_σ/s3/s4 sum stays full, so each emitted sub-block is
        // exact. Empty → build every output shell pair of each output cell.
        const std::vector<std::vector<uint8_t>>& output_shell_masks = {}) {
    ensure_libint_initialized();
    (void)system;  // Part of the shared signature; not read by this routine.

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());
    if (P_real_space.nbf != nbf) {
        throw std::runtime_error(
            "build_jk_2e_real_space: density nbf mismatch");
    }
    if (omega < 0.0) {
        throw std::runtime_error(
            "build_jk_2e_real_space: omega must be non-negative");
    }

    const auto n_c = cells.size();
    const std::size_t nshells = shells_ref.size();

    // ---- Validate the SYM3b output selection up front ---------------------
    // Done before any expensive work so a malformed request fails fast, and
    // before the parallel region so nothing can throw inside it.
    const int n_c_i = static_cast<int>(n_c);
    const bool use_subset = !output_indices.empty();
    const int n_out = use_subset
        ? static_cast<int>(output_indices.size())
        : n_c_i;
    if (use_subset) {
        std::vector<char> seen(n_c, 0);
        for (const int c : output_indices) {
            if (c < 0 || static_cast<std::size_t>(c) >= n_c) {
                throw std::runtime_error(
                    "build_jk_2e_real_space: output_indices entry is out of "
                    "range for the cell list");
            }
            // A repeated index would let two parallel iterations accumulate
            // into the same J/K block.
            if (seen[static_cast<std::size_t>(c)] != 0) {
                throw std::runtime_error(
                    "build_jk_2e_real_space: output_indices must not contain "
                    "duplicate cell indices");
            }
            seen[static_cast<std::size_t>(c)] = 1;
        }
    }
    const bool use_shell_mask = !output_shell_masks.empty();
    if (use_shell_mask) {
        if (static_cast<int>(output_shell_masks.size()) != n_out) {
            throw std::runtime_error(
                "build_jk_2e_real_space: output_shell_masks must be parallel to "
                "output_indices (one nshells*nshells mask per output cell)");
        }
        for (const auto& mask : output_shell_masks) {
            // Longer masks are tolerated (extra entries are never read).
            if (mask.size() < nshells * nshells) {
                throw std::runtime_error(
                    "build_jk_2e_real_space: each output shell mask must hold "
                    "at least nshells*nshells flags");
            }
        }
    }

    // Pre-shift shells once per cell: copy the reference shells and translate
    // every shell origin by the cell's Cartesian displacement.
    const auto shifted_copy = [&shells_ref](const Eigen::Vector3d& dr) {
        std::vector<libint2::Shell> out(shells_ref.begin(), shells_ref.end());
        for (auto& s : out) {
            s.O[0] += dr[0]; s.O[1] += dr[1]; s.O[2] += dr[2];
        }
        return out;
    };
    std::vector<std::vector<libint2::Shell>> shells_at(n_c);
    for (std::size_t c = 0; c < n_c; ++c) {
        shells_at[c] = shifted_copy(cells[c].r_cart);
    }

    // Density cell lookup by integer index, used to fetch P(h) where
    // h = g_σ − g_λ.
    const auto p_cell_index = build_cell_index_map(P_real_space.cells);

    // Optionally switch to the erfc-screened Coulomb kernel for Ewald
    // short-range ERIs (same convention as build_jk_gamma_molecular_limit).
    const libint2::Operator op = (omega > 0.0)
        ? libint2::Operator::erfc_coulomb
        : libint2::Operator::coulomb;
    libint2::Engine prototype(op, shells_ref.max_nprim(),
                              shells_ref.max_l(), 0);
    if (omega > 0.0) {
        prototype.set_params(omega);
    }
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();

    // ---- Cauchy–Schwarz pre-pass ------------------------------------------
    //
    // |⟨μ ν_g | λ_λ σ_σ⟩|  ≤  Q[c_g][s1, s2] · Q[c_h][s3, s4]
    //
    // (c_h = c_σ − c_λ for J; for K the cell decomposition differs —
    // see below). Skip the libint quartet call when the bound × density
    // envelope falls below ``opts.schwarz_threshold``. Without this the
    // cost is O(n_c³ · n_shells⁴) libint quartet calls per Fock build —
    // O(10⁹) for LiH/STO-3G with cutoff 15 bohr; with screening it
    // drops to seconds.
    const double schwarz_thr = opts.schwarz_threshold;
    const bool screen = (schwarz_thr > 0.0);
    const auto Q = screen
        ? compute_schwarz_factors_per_cell(shells_ref, shells_at, prototype)
        : std::vector<std::vector<double>>{};
    const double D_max = screen ? density_envelope(P_real_space.blocks) : 0.0;
    // Cell-level Schwarz: ``Q_max[c] = max_{s_a, s_b} Q[c][s_a, s_b]``.
    // For any cell triple (c_g, c_lam, c_sig) the bound on *any* shell
    // quartet is at most ``Q_max[c_g] · Q_max[c_h] · D_max`` (J) or
    // ``Q_max[c_lam] · Q_max[c_kh] · D_max · 0.5`` (K). If both fail the
    // threshold we can skip the entire shell-quartet loop — a 10-100×
    // speedup on real crystals where many triples have negligible
    // shell-pair overlap.
    //
    // Sx3: Density-weighted (LinK-style) per-shell-pair screening.
    // Compute per-shell-pair |D| max for each density block so the
    // shell-quartet screening below uses the local density envelope
    // instead of the global D_max.  For insulators the off-diagonal
    // density blocks decay exponentially, so this buys genuine O(N)
    // exchange screening.
    std::vector<Eigen::MatrixXd> Dpair_per_block;
    if (screen) {
        Dpair_per_block.reserve(P_real_space.blocks.size());
        for (const auto& B : P_real_space.blocks) {
            Eigen::MatrixXd Dpair(nshells, nshells);
            if (B.size() == 0) {
                Dpair.setZero();
            } else {
                for (std::size_t s = 0; s < nshells; ++s) {
                    const auto bf_s = shell2bf[s];
                    const auto n_s = shells_ref[s].size();
                    for (std::size_t t = 0; t < nshells; ++t) {
                        const auto bf_t = shell2bf[t];
                        const auto n_t = shells_ref[t].size();
                        double mx = 0.0;
                        for (std::size_t i = 0; i < n_s; ++i)
                            for (std::size_t j = 0; j < n_t; ++j)
                                mx = std::max(mx, std::fabs(B(static_cast<Eigen::Index>(bf_s + i),
                                                                static_cast<Eigen::Index>(bf_t + j))));
                        Dpair(static_cast<Eigen::Index>(s), static_cast<Eigen::Index>(t)) = mx;
                    }
                }
            }
            Dpair_per_block.push_back(std::move(Dpair));
        }
    }
    const auto Q_max = screen
        ? max_q_per_cell(Q)
        : std::vector<double>{};
    // Map cell-index → position in `cells`, for the J/K displacement lookups
    // (a displacement may not be in `cells` — if so its Q is treated as 0
    // by truncation and the corresponding bound vanishes).
    CellIndexMap cell_index_map;
    if (screen) cell_index_map = build_cell_index_map(cells);

    // Allocate result components: one nbf × nbf block per cell in `cells`.
    LatticeMatrixSet J_set;
    J_set.nbf = nbf;
    J_set.cells = cells;  // same cell list as ν-shift
    J_set.blocks.assign(n_c, Eigen::MatrixXd::Zero(nbf, nbf));

    LatticeMatrixSet K_set;
    K_set.nbf = nbf;
    K_set.cells = cells;
    K_set.blocks.assign(n_c, Eigen::MatrixXd::Zero(nbf, nbf));

    // Parallelise over c_g — each iteration writes to its own J/K block
    // (output indices were checked to be unique above), so no
    // synchronisation is needed even though the inner (c_λ, c_σ, s1..s4)
    // loops accumulate into those blocks.
    // Phase SYM3b: iterate the bra-cell loop over the caller's output subset
    // when supplied (point-group orbit representatives), else over all cells.
    #pragma omp parallel for schedule(dynamic)
    for (int oi = 0; oi < n_out; ++oi) {
        const int c_g = use_subset ? output_indices[oi] : oi;
        const std::vector<uint8_t>* mask_oi =
            use_shell_mask ? &output_shell_masks[oi] : nullptr;
        auto& engine = engines[static_cast<std::size_t>(omp_thread_index())];
        const auto& buf = engine.results();

        const auto& shells_g = shells_at[c_g];
        Eigen::MatrixXd& J_g = J_set.blocks[c_g];
        Eigen::MatrixXd& K_g = K_set.blocks[c_g];

        for (std::size_t c_lam = 0; c_lam < n_c; ++c_lam) {
            const auto& shells_lam = shells_at[c_lam];

            for (std::size_t c_sig = 0; c_sig < n_c; ++c_sig) {
                const auto& shells_sig = shells_at[c_sig];

                // Resolve the density block P(h), h = g_σ − g_λ, ONCE per
                // cell pair; the shell loops below reuse it instead of
                // repeating the hash lookup for every quartet.
                const Eigen::Vector3i h =
                    cells[c_sig].index - cells[c_lam].index;
                const auto p_it = p_cell_index.find(h);
                if (p_it == p_cell_index.end()) continue;  // outside density cutoff
                const int p_idx = p_it->second;
                const Eigen::MatrixXd& P_block =
                    P_real_space.blocks[static_cast<std::size_t>(p_idx)];
                // An empty density block contributes nothing; skipping it
                // also keeps the unscreened path from indexing a 0×0 matrix.
                if (P_block.size() == 0) continue;
                // Per-shell-pair |P| envelope of this block (screening only).
                const Eigen::MatrixXd* D_block = screen
                    ? &Dpair_per_block[static_cast<std::size_t>(p_idx)]
                    : nullptr;

                // Index of the c_h = c_σ − c_λ cell in `cells`, for the J
                // Schwarz bound (same displacement as the density lookup,
                // but resolved inside `cells`).
                int c_h_idx = -1;
                if (screen) {
                    auto it = cell_index_map.find(h);
                    if (it != cell_index_map.end()) c_h_idx = it->second;
                }
                // Cell index for the K bound's second pair displacement
                // (c_σ − c_g): the K integral pairs (s2_g, s4_σ) so the
                // shell-pair displacement seen by Q is c_σ − c_g.
                int c_kh_idx = -1;
                if (screen && compute_exchange) {
                    const Eigen::Vector3i kh =
                        cells[c_sig].index - cells[c_g].index;
                    auto it = cell_index_map.find(kh);
                    if (it != cell_index_map.end()) c_kh_idx = it->second;
                }

                // ---- Cell-level Schwarz: skip the whole shell-quartet
                // loop if neither J nor K can possibly contribute above
                // the threshold. ----
                if (screen) {
                    const bool j_possible = (c_h_idx >= 0) &&
                        (Q_max[c_g] * Q_max[c_h_idx] * D_max >= schwarz_thr);
                    const bool k_possible = compute_exchange &&
                        (c_kh_idx >= 0) &&
                        (Q_max[c_lam] * Q_max[c_kh_idx] * D_max *
                            0.5 >= schwarz_thr);
                    if (!j_possible && !k_possible) continue;
                }

                // For each AO shell quartet, compute J and K integrals.
                for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                    const auto bf1 = shell2bf[s1];
                    const auto n1 = shells_ref[s1].size();
                    for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                        // SYM3b shell-pair mask: skip output shell pairs
                        // (s1_home, s2_cg) that are not orbit representatives;
                        // the caller reconstructs them by point-group rotation.
                        if (mask_oi && (*mask_oi)[s1 * nshells + s2] == 0)
                            continue;
                        const auto bf2 = shell2bf[s2];
                        const auto n2 = shells_g[s2].size();
                        // Schwarz factor for the (s1_0, s2_g) pair.
                        const double q12 = screen
                            ? Q[c_g][s1 * nshells + s2]
                            : 0.0;
                        for (std::size_t s3 = 0; s3 < shells_lam.size(); ++s3) {
                            const auto bf3 = shell2bf[s3];
                            const auto n3 = shells_lam[s3].size();
                            // Schwarz factor for the (s1_0, s3_λ) pair —
                            // K integral's first half (relative displacement
                            // = c_λ).
                            const double q13 = (screen && compute_exchange)
                                ? Q[c_lam][s1 * nshells + s3]
                                : 0.0;
                            for (std::size_t s4 = 0; s4 < shells_sig.size(); ++s4) {
                                const auto bf4 = shell2bf[s4];
                                const auto n4 = shells_sig[s4].size();

                                // Sx3: local density envelope max|P_h(λ, σ)|
                                // over this (s3, s4) shell pair — the factor
                                // that multiplies both the J and the K
                                // integral of this quartet.
                                const double den34 = screen
                                    ? (*D_block)(
                                        static_cast<Eigen::Index>(s3),
                                        static_cast<Eigen::Index>(s4))
                                    : 0.0;

                                // -- J piece -----------------------------
                                bool do_J = true;
                                if (screen) {
                                    if (c_h_idx < 0) {
                                        do_J = false;
                                    } else {
                                        const double q34 =
                                            Q[c_h_idx][s3 * nshells + s4];
                                        if (q12 * q34 * den34 < schwarz_thr)
                                            do_J = false;
                                    }
                                }
                                if (do_J) {
                                    engine.compute(shells_ref[s1], shells_g[s2],
                                                   shells_lam[s3], shells_sig[s4]);
                                    // buf[0] is null when the engine produced
                                    // no integrals for this quartet.
                                    if (const double* blk = buf[0]) {
                                        for (std::size_t i = 0; i < n1; ++i)
                                        for (std::size_t j = 0; j < n2; ++j)
                                        for (std::size_t k = 0; k < n3; ++k)
                                        for (std::size_t l = 0; l < n4; ++l) {
                                            const double v = blk[
                                                ((i * n2 + j) * n3 + k) * n4 + l];
                                            J_g(bf1 + i, bf2 + j) +=
                                                P_block(bf3 + k, bf4 + l) * v;
                                        }
                                    }
                                }

                                // -- K piece (exchange) -----------------
                                if (compute_exchange) {
                                    bool do_K = true;
                                    if (screen) {
                                        if (c_kh_idx < 0) {
                                            do_K = false;
                                        } else {
                                            const double q24 =
                                                Q[c_kh_idx][s2 * nshells + s4];
                                            if (q13 * q24 * den34 * 0.5
                                                    < schwarz_thr)
                                                do_K = false;
                                        }
                                    }
                                    if (do_K) {
                                        // Shell order (s1, s3, s2, s4): the
                                        // result block is indexed (i, k, j, l).
                                        engine.compute(shells_ref[s1], shells_lam[s3],
                                                       shells_g[s2], shells_sig[s4]);
                                        if (const double* blk = buf[0]) {
                                            for (std::size_t i = 0; i < n1; ++i)
                                            for (std::size_t j = 0; j < n2; ++j)
                                            for (std::size_t k = 0; k < n3; ++k)
                                            for (std::size_t l = 0; l < n4; ++l) {
                                                const double v = blk[
                                                    ((i * n3 + k) * n2 + j) * n4 + l];
                                                K_g(bf1 + i, bf2 + j) +=
                                                    P_block(bf3 + k, bf4 + l) * v;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return JKLatticeMatrixSets{std::move(J_set), std::move(K_set)};
}

// Full real-space J/K build over the cutoff cell list.
//
// Generates the cell list with direct_lattice_cells(system, opts.cutoff_bohr)
// and forwards to build_jk_2e_real_space_impl with exchange enabled and no
// output subset or shell mask, so a J and a K block is built for every cell.
// `omega` is forwarded unchanged to the implementation.
JKLatticeMatrixSets build_jk_2e_real_space(const BasisSet& basis,
                                           const PeriodicSystem& system,
                                           const LatticeSumOptions& opts,
                                           const LatticeMatrixSet& P_real_space,
                                           double omega) {
    const auto cells = direct_lattice_cells(system, opts.cutoff_bohr);
    return build_jk_2e_real_space_impl(
        basis, system, opts, P_real_space, cells, true, omega);
}

// Same as build_jk_2e_real_space, except that the caller supplies the cell
// list instead of it being generated from opts.cutoff_bohr. Forwards to
// build_jk_2e_real_space_impl with exchange enabled and no output subset or
// shell mask, so a J and a K block is built for every entry of `cells`.
JKLatticeMatrixSets build_jk_2e_real_space_explicit(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const LatticeSumOptions& opts,
        const LatticeMatrixSet& P_real_space,
        const std::vector<LatticeCell>& cells,
        double omega) {
    return build_jk_2e_real_space_impl(
        basis, system, opts, P_real_space, cells, true, omega);
}

// Phase SYM3b: symmetry-reduced output build.
//
// J(g)/K(g) are built only for the bra cells listed in `output_indices`
// (point-group orbit representatives). The internal lattice sum still runs
// over the FULL cutoff cell list, so every emitted block is exact; the
// non-emitted blocks are left zero and are recovered by the caller via
// point-group reconstruction. This is the |G|-fold compute reduction for
// which the SYM3 storage round-trip (vibeqc.symmetry_integrals) served as
// the validation substrate.
//
// `output_indices` are positions into the cutoff cell list, i.e. the list
// direct_lattice_cells() returns for opts.cutoff_bohr.
JKLatticeMatrixSets build_jk_2e_real_space_output_subset(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const LatticeSumOptions& opts,
        const LatticeMatrixSet& P_real_space,
        const std::vector<int>& output_indices,
        double omega) {
    // J and K are both requested (same flag value as the unrestricted build).
    constexpr bool kWithExchange = true;
    return build_jk_2e_real_space_impl(
        basis, system, opts, P_real_space,
        direct_lattice_cells(system, opts.cutoff_bohr),
        kWithExchange, omega, output_indices);
}

// Phase SYM3b shell-pair mask.
//
// Refinement of build_jk_2e_real_space_output_subset: within each output
// (representative) cell only the atom-pair-orbit *representative* sub-blocks
// are built. Selection is driven by one nshells*nshells flag array per
// output cell:
//     output_shell_masks[oi][s1*nshells + s2] != 0  ->  build that output
//                                                       shell pair.
// The whole-cell subset still builds every shell pair of a representative
// cell, whereas the reconstruction only needs the representative atom-pair
// sub-blocks, so this is the finer reduction. The internal lattice sum stays
// full, hence each emitted sub-block is exact; non-emitted entries are
// recovered by point-group rotation.
//
// `output_indices` are positions into the cutoff cell list.
JKLatticeMatrixSets build_jk_2e_real_space_output_subset_masked(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const LatticeSumOptions& opts,
        const LatticeMatrixSet& P_real_space,
        const std::vector<int>& output_indices,
        const std::vector<std::vector<uint8_t>>& output_shell_masks,
        double omega) {
    // Cutoff-driven cell list that `output_indices` refers into.
    const auto cutoff_cells = direct_lattice_cells(system, opts.cutoff_bohr);
    // J and K are both requested (same flag value as the unmasked builds).
    constexpr bool kWithExchange = true;
    return build_jk_2e_real_space_impl(
        basis, system, opts, P_real_space, cutoff_cells, kWithExchange,
        omega, output_indices, output_shell_masks);
}

// Two-electron part of the real-space Fock lattice matrices:
//     F_2e(g) = J(g) - 0.5 * exchange_scale * K(g)
// over the cells returned by direct_lattice_cells() for opts.cutoff_bohr.
//
// exchange_scale == 0 takes a Coulomb-only path: the implementation is
// called with its exchange flag set to `false` and only J is returned.
//
// Throws std::runtime_error when omega is negative.
LatticeMatrixSet build_fock_2e_real_space(const BasisSet& basis,
                                          const PeriodicSystem& system,
                                          const LatticeSumOptions& opts,
                                          const LatticeMatrixSet& P_real_space,
                                          double exchange_scale,
                                          double omega) {
    // Reject a negative omega before any lattice work is done.
    if (omega < 0.0) {
        throw std::runtime_error(
            "build_fock_2e_real_space: omega must be non-negative");
    }

    const auto lattice_cells = direct_lattice_cells(system, opts.cutoff_bohr);

    // Coulomb-only path: no K is requested, J is the whole answer.
    if (exchange_scale == 0.0) {
        return build_jk_2e_real_space_impl(
            basis, system, opts, P_real_space, lattice_cells, false, omega).J;
    }

    JKLatticeMatrixSets coulomb_exchange = build_jk_2e_real_space_impl(
        basis, system, opts, P_real_space, lattice_cells, true, omega);

    // Fold the scaled exchange into J block by block, in place.
    const double k_prefactor = 0.5 * exchange_scale;
    const std::size_t n_blocks = coulomb_exchange.J.blocks.size();
    for (std::size_t blk = 0; blk != n_blocks; ++blk) {
        coulomb_exchange.J.blocks[blk] -=
            k_prefactor * coulomb_exchange.K.blocks[blk];
    }

    // J is a member of a local, so it has to be moved out explicitly.
    return std::move(coulomb_exchange.J);
}

// ---- Phase SYM3b explicit-cell variant -----------------------------------
// Same kernel as build_jk_gamma_molecular_limit, but accepts a caller-
// supplied cell list.  The original entry point delegates here.
//
// For every ordered pair (c_g, c_p) of entries of `cells` the kernel
// accumulates, with the shells of cell c shifted by cells[c].r_cart and
// mu always taken from the unshifted reference basis,
//     J(mu, nu) += P(lam, sig) * (mu  nu_g  | lam_p sig_p)
//     K(mu, nu) += P(lam, sig) * (mu  lam_p | nu_g  sig_p)
// and returns the two nbf x nbf sums.
//
// Parameters
//   basis   shells (basis.libint()) and basis-function count.
//   system  not referenced by this function body.
//   cells   lattice cells; `.r_cart` positions the shifted shell copies and
//           `.index` is used for the screening lookups.
//   opts    only opts.schwarz_threshold is read; a value > 0 turns Schwarz
//           screening on.
//   P       nbf x nbf density matrix.
//   omega   0 selects libint2::Operator::coulomb; > 0 selects
//           erfc_coulomb with omega passed to set_params().
//
// Throws std::runtime_error if P is not nbf x nbf or omega is negative.
//
// NOTE(review): screening also changes WHICH terms are summed, not only how
// cheaply. With screening on, every J quartet is skipped when `cells` has no
// (0,0,0) entry, and every K quartet of a pair is skipped when
// cells[c_p].index - cells[c_g].index is not itself in `cells`. With
// screening off those quartets are computed. Confirm this is intended for
// caller-supplied cell lists.
JKMatrices build_jk_gamma_molecular_limit_explicit(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const std::vector<LatticeCell>& cells,
        const LatticeSumOptions& opts,
        const Eigen::MatrixXd& P,
        double omega) {
    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());
    if (P.rows() != nbf || P.cols() != nbf)
        throw std::runtime_error("build_jk_gamma_molecular_limit_explicit: density shape mismatch");
    if (omega < 0.0)
        throw std::runtime_error("build_jk_gamma_molecular_limit_explicit: omega must be non-negative");

    // One prototype engine, cloned into a pool indexed by thread below.
    const libint2::Operator op = (omega > 0.0)
        ? libint2::Operator::erfc_coulomb
        : libint2::Operator::coulomb;
    libint2::Engine prototype(op, shells_ref.max_nprim(), shells_ref.max_l(), 0);
    if (omega > 0.0) prototype.set_params(omega);
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();
    const std::size_t nshells = shells_ref.size();

    // Shifted copy of the basis for every supplied cell.
    std::vector<std::vector<libint2::Shell>> shells_at(cells.size());
    for (std::size_t c = 0; c < cells.size(); ++c)
        shells_at[c] = shift_shells(shells_ref, cells[c].r_cart);

    // Screening data is only built when the threshold is positive.
    // Q[c] is indexed as Q[c][sa * nshells + sb]; presumably the Schwarz
    // factor of (reference shell sa, shell sb in cell c) -- confirm against
    // compute_schwarz_factors_per_cell.
    const double schwarz_thr = opts.schwarz_threshold;
    const bool screen = (schwarz_thr > 0.0);
    const auto Q = screen
        ? compute_schwarz_factors_per_cell(shells_ref, shells_at, prototype)
        : std::vector<std::vector<double>>{};
    const auto Q_max = screen ? max_q_per_cell(Q) : std::vector<double>{};
    // Largest |P| entry, floored at 1.0.
    const double D_max = screen
        ? std::max(1.0, std::max(std::fabs(P.maxCoeff()), std::fabs(P.minCoeff())))
        : 0.0;
    CellIndexMap cell_index_map;
    if (screen) cell_index_map = build_cell_index_map(cells);
    // Position of the (0,0,0) cell; its Q row screens the (lam_p sig_p) ket
    // of J, whose two shells share a cell.
    int c_zero_idx = -1;
    if (screen) {
        auto it = cell_index_map.find(Eigen::Vector3i(0, 0, 0));
        if (it != cell_index_map.end()) c_zero_idx = it->second;
    }
    const double q_zero_max = (screen && c_zero_idx >= 0) ? Q_max[c_zero_idx] : 0.0;

    // Flattened (c_g, c_p) loop; each thread accumulates into its own pair
    // of matrices, summed after the parallel region.
    const int n_c = static_cast<int>(cells.size());
    const int n_pairs = n_c * n_c;
    const int n_threads = omp_max_threads();
    std::vector<Eigen::MatrixXd> Jm_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
    std::vector<Eigen::MatrixXd> Km_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

    #pragma omp parallel for schedule(dynamic)
    for (int idx = 0; idx < n_pairs; ++idx) {
        const int c_g = idx / n_c;
        const int c_p = idx % n_c;
        const auto tid = static_cast<std::size_t>(omp_thread_index());
        auto& engine = engines[tid];
        const auto& buf = engine.results();
        auto& Jm_local = Jm_tls[tid];
        auto& Km_local = Km_tls[tid];
        const auto& shells_g = shells_at[c_g];
        const auto& shells_p = shells_at[c_p];
        // Cell holding the difference index p - g, whose Q row screens the
        // (nu_g sig_p) ket of K; -1 when that cell is not in `cells`.
        int c_pg_idx = -1;
        if (screen) {
            const Eigen::Vector3i pg = cells[c_p].index - cells[c_g].index;
            auto it = cell_index_map.find(pg);
            if (it != cell_index_map.end()) c_pg_idx = it->second;
        }
        // Pair-level bound: skip the whole pair when neither J nor K can
        // reach the threshold.
        if (screen) {
            const bool j_possible = (Q_max[c_g] * q_zero_max * D_max >= schwarz_thr);
            const double q_pg_max = (c_pg_idx >= 0) ? Q_max[c_pg_idx] : 0.0;
            const bool k_possible = (Q_max[c_p] * q_pg_max * D_max >= schwarz_thr);
            if (!j_possible && !k_possible) continue;
        }
        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                const auto bf2 = shell2bf[s2]; const auto n2 = shells_g[s2].size();
                const double q12_J = screen ? Q[c_g][s1 * nshells + s2] : 0.0;
                for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                    const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                    const double q13_K = screen ? Q[c_p][s1 * nshells + s3] : 0.0;
                    for (std::size_t s4 = 0; s4 < shells_p.size(); ++s4) {
                        const auto bf4 = shell2bf[s4]; const auto n4 = shells_p[s4].size();
                        // ---- Coulomb: (s1 s2_g | s3_p s4_p)
                        bool do_J = true;
                        if (screen) {
                            if (c_zero_idx < 0) { do_J = false; }
                            else {
                                const double q34_J = Q[c_zero_idx][s3 * nshells + s4];
                                if (q12_J * q34_J * D_max < schwarz_thr) do_J = false;
                            }
                        }
                        if (do_J) {
                            engine.compute(shells_ref[s1], shells_g[s2], shells_p[s3], shells_p[s4]);
                            // A null buf[0] is how libint2 reports a
                            // screened-out shell set; nothing to add then.
                            if (const double* blk = buf[0]) {
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    // Row-major over the compute() order (s1, s2, s3, s4).
                                    const double v = blk[((i * n2 + j) * n3 + k) * n4 + l];
                                    Jm_local(bf1+i, bf2+j) += P(bf3+k, bf4+l) * v;
                                }
                            }
                        }
                        // ---- Exchange: (s1 s3_p | s2_g s4_p)
                        bool do_K = true;
                        if (screen) {
                            if (c_pg_idx < 0) { do_K = false; }
                            else {
                                const double q24_K = Q[c_pg_idx][s2 * nshells + s4];
                                if (q13_K * q24_K * D_max < schwarz_thr) do_K = false;
                            }
                        }
                        if (do_K) {
                            engine.compute(shells_ref[s1], shells_p[s3], shells_g[s2], shells_p[s4]);
                            if (const double* blk = buf[0]) {
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    // compute() order is (s1, s3, s2, s4), so the
                                    // s3 index k sits in the second slot here.
                                    const double v = blk[((i * n3 + k) * n2 + j) * n4 + l];
                                    Km_local(bf1+i, bf2+j) += P(bf3+k, bf4+l) * v;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    // Sum the per-thread accumulators.
    Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
    Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
    for (const auto& m : Jm_tls) Jm += m;
    for (const auto& m : Km_tls) Km += m;
    // Move the sums into the result instead of deep-copying two nbf x nbf
    // matrices out of locals that are about to be destroyed.
    return JKMatrices{std::move(Jm), std::move(Km)};
}

// ---- Phase M3b per-cell-pair J/K contributions ---------------------------
// Same Γ-only molecular-limit shell-quartet loop as
// build_jk_gamma_molecular_limit_explicit, but each (c_g, c_p) pair's J and K
// contributions are stored separately instead of being summed. Parallelised
// over the supplied pair list; loop iteration i writes only to out[i], so
// no reduction or locking is needed.
//
// Parameters
//   basis   shells (basis.libint()) and basis-function count.
//   system  not referenced by this function body.
//   cells   lattice cells; `.r_cart` positions the shifted shell copies and
//           `.index` is used for the screening lookups.
//   pairs   (c_g, c_p) positions into `cells`; one output entry per element,
//           in the same order.
//   opts    only opts.schwarz_threshold is read; a value > 0 turns Schwarz
//           screening on.
//   P       nbf x nbf density matrix.
//   omega   0 selects libint2::Operator::coulomb; > 0 selects erfc_coulomb
//           with omega passed to set_params().
//
// Returns one PairJKContribution per input pair carrying c_g, c_p and the
// nbf x nbf J_contrib / K_contrib matrices. A pair skipped by the pair-level
// screen keeps its zero-initialised matrices.
//
// Throws std::runtime_error if P is not nbf x nbf, omega is negative, or a
// pair index lies outside [0, cells.size()).
std::vector<PairJKContribution> build_jk_pair_contributions(
        const BasisSet& basis,
        const PeriodicSystem& system,
        const std::vector<LatticeCell>& cells,
        const std::vector<std::pair<int, int>>& pairs,
        const LatticeSumOptions& opts,
        const Eigen::MatrixXd& P,
        double omega) {
    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());
    if (P.rows() != nbf || P.cols() != nbf)
        throw std::runtime_error("build_jk_pair_contributions: density shape mismatch");
    if (omega < 0.0)
        throw std::runtime_error("build_jk_pair_contributions: omega must be non-negative");

    // Validate every pair up front so the parallel loop can index
    // shells_at / cells without further checks.
    const int n_c = static_cast<int>(cells.size());
    for (const auto& pr : pairs) {
        if (pr.first < 0 || pr.first >= n_c ||
            pr.second < 0 || pr.second >= n_c)
            throw std::runtime_error(
                "build_jk_pair_contributions: pair cell index out of range");
    }

    // One prototype engine, cloned into a pool indexed by thread below.
    const libint2::Operator op = (omega > 0.0)
        ? libint2::Operator::erfc_coulomb
        : libint2::Operator::coulomb;
    libint2::Engine prototype(op, shells_ref.max_nprim(), shells_ref.max_l(), 0);
    if (omega > 0.0) prototype.set_params(omega);
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();
    const std::size_t nshells = shells_ref.size();

    // Shifted copy of the basis for every supplied cell.
    std::vector<std::vector<libint2::Shell>> shells_at(cells.size());
    for (std::size_t c = 0; c < cells.size(); ++c)
        shells_at[c] = shift_shells(shells_ref, cells[c].r_cart);

    // Screening data is only built when the threshold is positive.
    // Q[c] is indexed as Q[c][sa * nshells + sb]; presumably the Schwarz
    // factor of (reference shell sa, shell sb in cell c) -- confirm against
    // compute_schwarz_factors_per_cell.
    const double schwarz_thr = opts.schwarz_threshold;
    const bool screen = (schwarz_thr > 0.0);
    const auto Q = screen
        ? compute_schwarz_factors_per_cell(shells_ref, shells_at, prototype)
        : std::vector<std::vector<double>>{};
    const auto Q_max = screen ? max_q_per_cell(Q) : std::vector<double>{};
    // Largest |P| entry, floored at 1.0.
    const double D_max = screen
        ? std::max(1.0, std::max(std::fabs(P.maxCoeff()), std::fabs(P.minCoeff())))
        : 0.0;
    CellIndexMap cell_index_map;
    if (screen) cell_index_map = build_cell_index_map(cells);
    // Position of the (0,0,0) cell; its Q row screens the (lam_p sig_p) ket
    // of J. NOTE(review): if `cells` has no (0,0,0) entry, every J quartet is
    // skipped while screening is on (see do_J below) -- confirm intended.
    int c_zero_idx = -1;
    if (screen) {
        auto it = cell_index_map.find(Eigen::Vector3i(0, 0, 0));
        if (it != cell_index_map.end()) c_zero_idx = it->second;
    }
    const double q_zero_max = (screen && c_zero_idx >= 0) ? Q_max[c_zero_idx] : 0.0;

    // Pre-allocate one output slot per pair; each loop iteration owns a distinct slot.
    const int n_pairs = static_cast<int>(pairs.size());
    std::vector<PairJKContribution> out(pairs.size());
    for (int i = 0; i < n_pairs; ++i) {
        out[i].c_g = pairs[i].first;
        out[i].c_p = pairs[i].second;
        out[i].J_contrib = Eigen::MatrixXd::Zero(nbf, nbf);
        out[i].K_contrib = Eigen::MatrixXd::Zero(nbf, nbf);
    }

    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < n_pairs; ++i) {
        const int c_g = pairs[i].first;
        const int c_p = pairs[i].second;
        const auto tid = static_cast<std::size_t>(omp_thread_index());
        auto& engine = engines[tid];
        const auto& buf = engine.results();
        // Accumulate straight into this pair's own output matrices.
        Eigen::MatrixXd& Jm_local = out[i].J_contrib;
        Eigen::MatrixXd& Km_local = out[i].K_contrib;
        const auto& shells_g = shells_at[c_g];
        const auto& shells_p = shells_at[c_p];
        // Cell holding the difference index p - g, whose Q row screens the
        // (nu_g sig_p) ket of K; -1 when that cell is not in `cells`.
        int c_pg_idx = -1;
        if (screen) {
            const Eigen::Vector3i pg = cells[c_p].index - cells[c_g].index;
            auto it = cell_index_map.find(pg);
            if (it != cell_index_map.end()) c_pg_idx = it->second;
        }
        // Pair-level bound: leave the slot at zero when neither J nor K can
        // reach the threshold.
        if (screen) {
            const bool j_possible = (Q_max[c_g] * q_zero_max * D_max >= schwarz_thr);
            const double q_pg_max = (c_pg_idx >= 0) ? Q_max[c_pg_idx] : 0.0;
            const bool k_possible = (Q_max[c_p] * q_pg_max * D_max >= schwarz_thr);
            if (!j_possible && !k_possible) continue;
        }
        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                const auto bf2 = shell2bf[s2]; const auto n2 = shells_g[s2].size();
                const double q12_J = screen ? Q[c_g][s1 * nshells + s2] : 0.0;
                for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                    const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                    const double q13_K = screen ? Q[c_p][s1 * nshells + s3] : 0.0;
                    for (std::size_t s4 = 0; s4 < shells_p.size(); ++s4) {
                        const auto bf4 = shell2bf[s4]; const auto n4 = shells_p[s4].size();
                        // ---- Coulomb: (s1 s2_g | s3_p s4_p)
                        bool do_J = true;
                        if (screen) {
                            if (c_zero_idx < 0) { do_J = false; }
                            else {
                                const double q34_J = Q[c_zero_idx][s3 * nshells + s4];
                                if (q12_J * q34_J * D_max < schwarz_thr) do_J = false;
                            }
                        }
                        if (do_J) {
                            engine.compute(shells_ref[s1], shells_g[s2], shells_p[s3], shells_p[s4]);
                            // A null buf[0] is how libint2 reports a
                            // screened-out shell set; nothing to add then.
                            if (const double* blk = buf[0]) {
                                for (std::size_t i_ = 0; i_ < n1; ++i_)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    // Row-major over the compute() order (s1, s2, s3, s4).
                                    const double v = blk[((i_ * n2 + j) * n3 + k) * n4 + l];
                                    Jm_local(bf1+i_, bf2+j) += P(bf3+k, bf4+l) * v;
                                }
                            }
                        }
                        // ---- Exchange: (s1 s3_p | s2_g s4_p)
                        bool do_K = true;
                        if (screen) {
                            if (c_pg_idx < 0) { do_K = false; }
                            else {
                                const double q24_K = Q[c_pg_idx][s2 * nshells + s4];
                                if (q13_K * q24_K * D_max < schwarz_thr) do_K = false;
                            }
                        }
                        if (do_K) {
                            engine.compute(shells_ref[s1], shells_p[s3], shells_g[s2], shells_p[s4]);
                            if (const double* blk = buf[0]) {
                                for (std::size_t i_ = 0; i_ < n1; ++i_)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    // compute() order is (s1, s3, s2, s4), so the
                                    // s3 index k sits in the second slot here.
                                    const double v = blk[((i_ * n3 + k) * n2 + j) * n4 + l];
                                    Km_local(bf1+i_, bf2+j) += P(bf3+k, bf4+l) * v;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return out;
}


// ---- CCM (Cyclic Cluster Model) WSSC-weighted four-center J/K build ------
// Two methods, selectable via `method`:
//   "bra_home" — bra at home, ket imaged, bra-ket symmetrised via a
//        second ERI pass (lambda-home, sigma-home | mu_{-p}, nu_{-p}).
//        Matches vibe-qc's padded ccm_eri frame, validated to ~1e-5.
//   "union12"  — nu imaged over WSC(mu), lambda in union WSC, K
//        symmetrised per spec (AICCM reference frame, research).
JKMatrices build_jk_ccm_weighted(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const std::vector<LatticeCell>& cells,
    const LatticeSumOptions& opts,
    const Eigen::MatrixXd& P_gamma,
    const std::vector<Eigen::Vector3i>& weight_cells,
    const std::vector<Eigen::MatrixXd>& weight_matrices,
    const std::string& method,
    double omega) {
    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());

    const auto shell_info = basis.shells();
    std::vector<int> atom_of_shell(shell_info.size());
    for (std::size_t s = 0; s < shell_info.size(); ++s)
        atom_of_shell[s] = shell_info[s].atom_index;

    if (weight_matrices.empty())
        throw std::runtime_error("build_jk_ccm_weighted: empty weights");
    const int n_atoms = static_cast<int>(weight_matrices[0].rows());
    if (P_gamma.rows() != nbf || P_gamma.cols() != nbf)
        throw std::runtime_error("build_jk_ccm_weighted: density shape mismatch");
    if (omega < 0.0)
        throw std::runtime_error("build_jk_ccm_weighted: omega >= 0 required");
    if (weight_cells.size() != weight_matrices.size())
        throw std::runtime_error("build_jk_ccm_weighted: weight size mismatch");

    WeightLookup wlookup;
    wlookup.reserve(weight_cells.size());
    for (std::size_t wi = 0; wi < weight_cells.size(); ++wi) {
        const auto& W = weight_matrices[wi];
        if (W.rows() != n_atoms || W.cols() != n_atoms)
            throw std::runtime_error("build_jk_ccm_weighted: weight shape mismatch");
        wlookup[weight_cells[wi]] = W;
    }
    auto w = [&](const Eigen::Vector3i& g, int A, int B) -> double {
        auto it = wlookup.find(g);
        if (it == wlookup.end()) return 0.0;
        if (A < 0 || A >= n_atoms || B < 0 || B >= n_atoms) return 0.0;
        return it->second(A, B);
    };

    const libint2::Operator op = (omega > 0.0)
        ? libint2::Operator::erfc_coulomb : libint2::Operator::coulomb;
    libint2::Engine prototype(op, shells_ref.max_nprim(), shells_ref.max_l(), 0);
    if (omega > 0.0) prototype.set_params(omega);
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();
    const std::size_t nshells = shells_ref.size();

    std::vector<std::vector<libint2::Shell>> shells_at(cells.size());
    for (std::size_t c = 0; c < cells.size(); ++c)
        shells_at[c] = shift_shells(shells_ref, cells[c].r_cart);

    const double schwarz_thr = opts.schwarz_threshold;
    const bool screen = (schwarz_thr > 0.0);
    const auto Q = screen
        ? compute_schwarz_factors_per_cell(shells_ref, shells_at, prototype)
        : std::vector<std::vector<double>>{};
    const auto Q_max = screen ? max_q_per_cell(Q) : std::vector<double>{};
    const double D_max = screen
        ? std::max(1.0, std::max(std::fabs(P_gamma.maxCoeff()),
                                  std::fabs(P_gamma.minCoeff()))) : 0.0;
    CellIndexMap cell_index_map;
    if (screen) cell_index_map = build_cell_index_map(cells);
    int c_zero_idx = -1;
    if (screen) {
        auto it = cell_index_map.find(Eigen::Vector3i(0, 0, 0));
        if (it != cell_index_map.end()) c_zero_idx = it->second;
    }

    const int n_c = static_cast<int>(cells.size());
    const int n_threads = omp_max_threads();
    const Eigen::Vector3i home(0, 0, 0);

    // Find home cell.
    int home_idx = -1;
    for (int ci = 0; ci < n_c; ++ci)
        if (cells[ci].index == home) { home_idx = ci; break; }
    if (home_idx < 0)
        throw std::runtime_error("build_jk_ccm_weighted: home cell not found");
    const auto& shells_home = shells_at[home_idx];

    // Pre-build map from cell index to position in cells array (for
    // looking up the negated cell -c_p).
    std::unordered_map<Eigen::Vector3i, int, LatticeIndexHash, LatticeIndexEq> cell_pos_map;
    for (int ci = 0; ci < n_c; ++ci)
        cell_pos_map[cells[ci].index] = ci;

    if (method == "bra_home") {
        // ================================================================
        // bra_home: ket-folded + bra-folded with independent sigma offset.
        //
        // Ket-folded: w_Jk = W0[a,b] * (Wp[a,c]+Wp[b,c])/2 * W{gd}[c,d]
        //             ERI = (mu_0 nu_0 | lambda_p sigma_q)  [q = p + gd]
        // Bra-folded: w_Jb = W{gd}[c,d] * (W{-p}[c,a]+W{-q}[d,a])/2 * W0[a,b]
        //             ERI = (lambda_0 sigma_0 | mu_{-p} nu_{-q})
        // J = 0.5*(Jk + Jb), same for K.
        // ================================================================

        // Pre-build sorted list of weight-cell offsets (gd values).
        std::vector<Eigen::Vector3i> gd_list(weight_cells.begin(), weight_cells.end());

        std::vector<Eigen::MatrixXd> Jk_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
        std::vector<Eigen::MatrixXd> Kk_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
        std::vector<Eigen::MatrixXd> Jb_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
        std::vector<Eigen::MatrixXd> Kb_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

        const int n_gd = static_cast<int>(gd_list.size());
        const int n_triples = n_c * n_gd;

        #pragma omp parallel for schedule(dynamic)
        for (int idx = 0; idx < n_triples; ++idx) {
            const int c_p = idx / n_gd;
            const int igd  = idx % n_gd;
            const Eigen::Vector3i gd = gd_list[igd];
            const Eigen::Vector3i cell_p = cells[c_p].index;
            const Eigen::Vector3i cell_q(cell_p[0]+gd[0], cell_p[1]+gd[1], cell_p[2]+gd[2]);
            const Eigen::Vector3i neg_p(-cell_p[0], -cell_p[1], -cell_p[2]);
            const Eigen::Vector3i neg_q(-cell_q[0], -cell_q[1], -cell_q[2]);

            const auto tid = static_cast<std::size_t>(omp_thread_index());
            auto& engine = engines[tid];
            const auto& buf = engine.results();
            auto& Jk = Jk_tls[tid]; auto& Kk = Kk_tls[tid];
            auto& Jb = Jb_tls[tid]; auto& Kb = Kb_tls[tid];

            const auto& shells_p = shells_at[c_p];

            // Find shells at cell_q and cell_{-q}.
            int q_idx = -1, nq_idx = -1;
            auto it_q = cell_pos_map.find(cell_q);
            if (it_q != cell_pos_map.end()) q_idx = it_q->second;
            auto it_nq = cell_pos_map.find(neg_q);
            if (it_nq != cell_pos_map.end()) nq_idx = it_nq->second;
            if (q_idx < 0) continue;   // sigma cell must exist
            const auto& shells_q = shells_at[q_idx];
            const std::vector<libint2::Shell>* shells_nq = (nq_idx >= 0) ? &shells_at[nq_idx] : nullptr;

            for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
                const int a = atom_of_shell[s1];                       // mu (home)
                for (std::size_t s2 = 0; s2 < shells_home.size(); ++s2) {
                    const auto bf2 = shell2bf[s2]; const auto n2 = shells_home[s2].size();
                    const int b = atom_of_shell[s2];                    // nu (home)
                    for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                        const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                        const int c = atom_of_shell[s3];               // lambda (cell p)
                        for (std::size_t s4 = 0; s4 < shells_q.size(); ++s4) {
                            const auto bf4 = shell2bf[s4]; const auto n4 = shells_q[s4].size();
                            const int d = atom_of_shell[s4];            // sigma (cell q)

                            // ---- ket-folded J: (mu_0 nu_0 | lambda_p sigma_q)
                            {
                                const double w_ab = w(home, a, b);
                                const double w_ac = w(cell_p, a, c);
                                const double w_bc = w(cell_p, b, c);
                                const double w_cd = w(gd, c, d);
                                const double w_Jk = w_ab * 0.5 * (w_ac + w_bc) * w_cd;
                                const Eigen::Vector3i neg_gd(-gd[0], -gd[1], -gd[2]);
                                const double w_Jb = w(home, c, d) * 0.5
                                    * (w(neg_p, c, a) + w(neg_p, d, a)) * w(neg_gd, a, b);

                                if (w_Jk != 0.0 || w_Jb != 0.0) {
                                    if (w_Jk != 0.0) {
                                        engine.compute(shells_ref[s1], shells_home[s2],
                                                       shells_p[s3], shells_q[s4]);
                                        if (const double* blk = buf[0])
                                            for (std::size_t i = 0; i < n1; ++i)
                                            for (std::size_t j = 0; j < n2; ++j)
                                            for (std::size_t k = 0; k < n3; ++k)
                                            for (std::size_t l = 0; l < n4; ++l) {
                                                const double v = blk[((i*n2+j)*n3+k)*n4+l];
                                                Jk(bf1+i, bf2+j) += P_gamma(bf3+k, bf4+l) * v * w_Jk;
                                            }
                                    }
                                    if (w_Jb != 0.0 && shells_nq) {
                                        const auto& snq = *shells_nq;
                                        engine.compute(shells_ref[s3], shells_ref[s4],
                                                       snq[s1], snq[s2]);
                                        if (const double* blk = buf[0])
                                            for (std::size_t i = 0; i < n3; ++i)
                                            for (std::size_t j = 0; j < n4; ++j)
                                            for (std::size_t k = 0; k < n1; ++k)
                                            for (std::size_t l = 0; l < n2; ++l) {
                                                const double v = blk[((i*n4+j)*n1+k)*n2+l];
                                                Jb(bf1+k, bf2+l) += P_gamma(bf4+i, bf3+j) * v * w_Jb;
                                            }
                                    }
                                }
                            }

                            // ---- ket-folded K: (mu_0 lambda_p | nu_0 sigma_q)
                            {
                                const double wk_ab = w(cell_p, a, c);
                                const double wk_ac = w(home, a, b);
                                const double wk_bc = w(neg_p, c, b);
                                const double wk_cd = w(cell_q, b, d);  // No: w(gd, b, d)?
                                // Wait — K has: bra=(mu_0, lambda_p), ket=(nu_0, sigma_q)
                                // ket-pair displacement = q-0 = q. So omega_cd = W[q][b,d]? No.
                                // Actually: nu at home (cell 0), sigma at cell q. 
                                // The WSSC weight omega_cd is w(cell_q, b, d) = W[q][b,d].
                                // But we need w(cell_q)[b,d], not w(gd)[b,d].
                                // gd = cell_q - cell_p. cell_q is the actual cell of sigma.
                                // w(cell_q, b, d) uses the ACTUAL cell of sigma relative to nu at home.
                                const double w_Kk_cd = w(cell_q, b, d);
                                const double w_Kk = wk_ab * 0.5 * (wk_ac + wk_bc) * w_Kk_cd;
                                const double w_Kb = w(home, b, d) * 0.5
                                    * (w(home, b, a) + w(home, d, a)) * w(neg_p, a, c);

                                if (w_Kk != 0.0 || w_Kb != 0.0) {
                                    if (w_Kk != 0.0) {
                                        engine.compute(shells_ref[s1], shells_p[s3],
                                                       shells_home[s2], shells_q[s4]);
                                        if (const double* blk = buf[0])
                                            for (std::size_t i = 0; i < n1; ++i)
                                            for (std::size_t j = 0; j < n2; ++j)
                                            for (std::size_t k = 0; k < n3; ++k)
                                            for (std::size_t l = 0; l < n4; ++l) {
                                                const double v = blk[((i*n3+k)*n2+j)*n4+l];
                                                Kk(bf1+i, bf2+j) += P_gamma(bf3+k, bf4+l) * v * w_Kk;
                                            }
                                    }
                                    if (w_Kb != 0.0 && shells_nq) {
                                        const auto& snq = *shells_nq;
                                        engine.compute(shells_home[s2], shells_q[s4],
                                                       snq[s1], shells_ref[s3]);
                                        if (const double* blk = buf[0])
                                            for (std::size_t i = 0; i < n2; ++i)
                                            for (std::size_t j = 0; j < n4; ++j)
                                            for (std::size_t k = 0; k < n1; ++k)
                                            for (std::size_t l = 0; l < n3; ++l) {
                                                const double v = blk[((i*n4+j)*n1+k)*n3+l];
                                                Kb(bf1+k, bf2+i) += P_gamma(bf4+j, bf3+l) * v * w_Kb;
                                            }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Reduce and symmetrise
        Eigen::MatrixXd Jkm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Kkm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Jbm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Kbm = Eigen::MatrixXd::Zero(nbf, nbf);
        for (const auto& m : Jk_tls) Jkm += m;
        for (const auto& m : Kk_tls) Kkm += m;
        for (const auto& m : Jb_tls) Jbm += m;
        for (const auto& m : Kb_tls) Kbm += m;
        return JKMatrices{0.5*(Jkm+Jbm), 0.5*(Kkm+Kbm)};
    }

    if (method == "bra_home_full") {
        // ================================================================
        // bra_home_full: build effective ERI tensor V, symmetrise, contract.
        // Matches the padded route's ccm_eri exactly: V is built with
        // bra at home, ket imaged (c_p + gd), then bra-ket symmetrised
        // as V_sym = 0.5*(V + V^T), then J = contract(P, V_sym).
        // O(nbf^4) memory — small-system reference implementation.
        //
        // Returns JKMatrices{J, K}, each Hermitised as 0.5*(M + M^T).
        // ================================================================

        const int nbf2 = nbf * nbf;
        std::vector<Eigen::MatrixXd> Vj_tls(n_threads, Eigen::MatrixXd::Zero(nbf2, nbf2));

        // Flatten (ket-anchor cell c_p, relative ket offset gd) into a single
        // index so the OpenMP loop can hand the pairs out dynamically.
        std::vector<Eigen::Vector3i> gd_list(weight_cells.begin(), weight_cells.end());
        const int n_gd = static_cast<int>(gd_list.size());
        const int n_triples = n_c * n_gd;

        #pragma omp parallel for schedule(dynamic)
        for (int idx = 0; idx < n_triples; ++idx) {
            const int c_p = idx / n_gd;
            const int igd  = idx % n_gd;
            const Eigen::Vector3i gd = gd_list[igd];
            const Eigen::Vector3i cell_p = cells[c_p].index;
            const Eigen::Vector3i cell_q(cell_p[0]+gd[0], cell_p[1]+gd[1], cell_p[2]+gd[2]);

            // Each thread works on its own engine and its own partial tensor;
            // the partial tensors are summed after the parallel loop.
            const auto tid = static_cast<std::size_t>(omp_thread_index());
            auto& engine = engines[tid];
            const auto& buf = engine.results();
            auto& Vj = Vj_tls[tid];
            const auto& shells_p = shells_at[c_p];

            // sigma sits in cell_q = cell_p + gd; pairs whose cell_q has no entry
            // in cell_pos_map contribute nothing.
            int q_idx = -1;
            auto it_q = cell_pos_map.find(cell_q);
            if (it_q != cell_pos_map.end()) q_idx = it_q->second;
            if (q_idx < 0) continue;
            const auto& shells_q = shells_at[q_idx];

            // Bra (mu_0 nu_0) is the home cell: BOTH bra indices must use the
            // same (reference) shell set. Mixing shells_ref[s1] with
            // shells_home[s2] makes mu and nu come from distinct shell objects
            // and breaks the mu<->nu symmetry of the effective tensor (the
            // wrap/long-range elements come out asymmetric). Use shells_ref for
            // both — verified against the padded ccm_eri reference.
            for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
                const int a = atom_of_shell[s1];
                for (std::size_t s2 = 0; s2 < shells_ref.size(); ++s2) {
                    const auto bf2 = shell2bf[s2]; const auto n2 = shells_ref[s2].size();
                    const int b = atom_of_shell[s2];

                    // omega_ab(0) depends only on the bra atoms, so look it up
                    // once per bra pair instead of once per quartet.
                    const double w_ab = w(home, a, b);
                    if (w_ab == 0.0) continue;

                    for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                        const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                        const int c = atom_of_shell[s3];

                        // (a, b, c)-only prefix of the eq-18 weight. The product is
                        // formed in the same left-to-right order as the full
                        // expression, so w_J below is bit-identical to evaluating
                        // w_ab * 0.5 * (w_ac + w_bc) * w_cd in one go.
                        const double w_abc = w_ab * 0.5
                            * (w(cell_p, a, c) + w(cell_p, b, c));
                        // A zero prefix makes w_J zero for every sigma shell
                        // (assuming w returns finite values), so skip the s4 loop.
                        if (w_abc == 0.0) continue;

                        for (std::size_t s4 = 0; s4 < shells_q.size(); ++s4) {
                            const auto bf4 = shell2bf[s4]; const auto n4 = shells_q[s4].size();
                            const int d = atom_of_shell[s4];

                            // eq-18 four-center weight, bra at home (matches padded
                            // ccm_eri): w_J = omega_ab(0) * 0.5(omega_ac+omega_bc)(p)
                            //                * omega_cd(gd).
                            const double w_J = w_abc * w(gd, c, d);
                            if (w_J == 0.0) continue;
                            // One effective ERI tensor V[mu nu, la si] from the
                            // home-bra / ket-imaged integral (mu_0 nu_0 | lambda_p sigma_q).
                            // J and K are both contracted from its bra-ket
                            // symmetrisation below (no separate K tensor).
                            engine.compute(shells_ref[s1], shells_ref[s2],
                                           shells_p[s3], shells_q[s4]);
                            // A null buf[0] means the engine produced no block for
                            // this quartet; nothing is accumulated in that case.
                            if (const double* blk = buf[0])
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double v = blk[((i*n2+j)*n3+k)*n4+l];
                                    const int mu = bf1 + i, nu = bf2 + j;
                                    const int la = bf3 + k, si = bf4 + l;
                                    Vj(mu*nbf+nu, la*nbf+si) += v * w_J;
                                }
                        }
                    }
                }
            }
        }

        // Reduce across threads.
        Eigen::MatrixXd V_full = Eigen::MatrixXd::Zero(nbf2, nbf2);
        for (const auto& m : Vj_tls) V_full += m;

        // Bra-ket symmetrise: V_sym = 0.5*(V + V^T) — the (mu nu|la si) <-> (la si|mu nu)
        // swap, exactly the padded ccm_eri's 0.5*(eff + eff^{(cd|ab)}). NB: assign
        // to a SEPARATE matrix — `V_full = 0.5*(V_full + V_full.transpose())`
        // aliases in Eigen (the transpose reads entries already overwritten) and
        // silently yields an ASYMMETRIC tensor, breaking the SCF.
        const Eigen::MatrixXd V_sym = 0.5 * (V_full + V_full.transpose());

        // Contract BOTH J and K from the SAME symmetrised tensor, matching the
        // padded run_ccm_rhf contractions J=einsum("mnrs,rs"), K=einsum("msrn,rs"):
        //   J[mu,nu] = sum_{la,si} P[la,si] * V_sym[mu*n+nu, la*n+si]
        //   K[mu,nu] = sum_{la,si} P[la,si] * V_sym[mu*n+si, la*n+nu]
        Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
        for (int mu = 0; mu < nbf; ++mu) {
            for (int nu = 0; nu < nbf; ++nu) {
                double jsum = 0.0, ksum = 0.0;
                for (int la = 0; la < nbf; ++la) {
                    for (int si = 0; si < nbf; ++si) {
                        const double p = P_gamma(la, si);
                        if (p != 0.0) {
                            jsum += p * V_sym(mu*nbf+nu, la*nbf+si);
                            ksum += p * V_sym(mu*nbf+si, la*nbf+nu);
                        }
                    }
                }
                Jm(mu, nu) = jsum;
                Km(mu, nu) = ksum;
            }
        }
        // The eq-18 CCM frame is slightly non-Hermitian (~1e-3 — the residual
        // cyclic-invariance breaking, present in the padded reference too).
        // Hermitise J and K so the Fock is symmetric and the SCF gradient
        // converges (mirrors run_ccm_rhf's F = 0.5*(F + F^T)). Separate
        // destinations: A = 0.5*(A + A.transpose()) aliases in Eigen.
        const Eigen::MatrixXd Jh = 0.5 * (Jm + Jm.transpose());
        const Eigen::MatrixXd Kh = 0.5 * (Km + Km.transpose());
        return JKMatrices{Jh, Kh};
    }

    if (method == "bra_home_full-direct") {
        // ================================================================
        // bra_home_full-direct: the INTEGRAL-DIRECT form of "bra_home_full"
        // (eq-18 weight; Python method "union12"). Same triple loop, same
        // w_J, same bra-ket symmetrisation as the full-tensor branch above,
        // but folds each weighted quartet straight into J/K rather than the
        // O(nbf^4) tensor V. See the aiccm2026dev-a-direct branch below for the
        // fold algebra; the full "bra_home_full" branch is preserved above as
        // the comparison reference. Result agrees to ~1e-12 (summation reorder).
        //
        // Returns JKMatrices{J, K}, each Hermitised as 0.5*(M + M^T).
        // ================================================================

        // One J and one K accumulator per thread; summed after the loop.
        std::vector<Eigen::MatrixXd> J_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
        std::vector<Eigen::MatrixXd> K_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

        // Flatten (ket-anchor cell c_p, relative ket offset gd) into one index.
        std::vector<Eigen::Vector3i> gd_list(weight_cells.begin(), weight_cells.end());
        const int n_gd = static_cast<int>(gd_list.size());
        const int n_triples = n_c * n_gd;

        #pragma omp parallel for schedule(dynamic)
        for (int idx = 0; idx < n_triples; ++idx) {
            const int c_p = idx / n_gd;
            const int igd  = idx % n_gd;
            const Eigen::Vector3i gd = gd_list[igd];
            const Eigen::Vector3i cell_p = cells[c_p].index;
            const Eigen::Vector3i cell_q(cell_p[0]+gd[0], cell_p[1]+gd[1], cell_p[2]+gd[2]);

            const auto tid = static_cast<std::size_t>(omp_thread_index());
            auto& engine = engines[tid];
            const auto& buf = engine.results();
            auto& Jloc = J_tls[tid];
            auto& Kloc = K_tls[tid];
            const auto& shells_p = shells_at[c_p];

            // sigma sits in cell_q = cell_p + gd; pairs whose cell_q has no entry
            // in cell_pos_map contribute nothing.
            int q_idx = -1;
            auto it_q = cell_pos_map.find(cell_q);
            if (it_q != cell_pos_map.end()) q_idx = it_q->second;
            if (q_idx < 0) continue;
            const auto& shells_q = shells_at[q_idx];

            // Opt-in Schwarz screening (active only when screen). Ket pair
            // (lambda_p, sigma_q) separation = gd -> Q at idx(gd); bra is home ->
            // Q at c_zero_idx. eq-18 weight w_J in [0,1], so w_J*Q_bra*Q_ket*D_max
            // bounds the J/K contribution. See the aiccm2026dev-a-direct note.
            int gd_idx = -1;
            if (screen) { auto itg = cell_index_map.find(gd); if (itg != cell_index_map.end()) gd_idx = itg->second; }
            const bool do_screen = screen && gd_idx >= 0 && c_zero_idx >= 0;

            for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
                const int a = atom_of_shell[s1];
                for (std::size_t s2 = 0; s2 < shells_ref.size(); ++s2) {
                    const auto bf2 = shell2bf[s2]; const auto n2 = shells_ref[s2].size();
                    const int b = atom_of_shell[s2];

                    // omega_ab(0): looked up once per bra pair and reused for
                    // every ket quartet below (no per-quartet re-lookup).
                    const double w_ab = w(home, a, b);
                    if (w_ab == 0.0) continue;
                    const double q_bra = do_screen ? Q[c_zero_idx][s1*nshells + s2] : 0.0;
                    // Bra-level early-out: even the largest ket factor at gd
                    // cannot lift this bra pair above the threshold.
                    if (do_screen && q_bra * Q_max[gd_idx] * D_max < schwarz_thr) continue;

                    for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                        const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                        const int c = atom_of_shell[s3];

                        // (a, b, c)-only prefix of the eq-18 weight, formed in the
                        // same left-to-right order as the full product so that
                        // w_J is bit-identical to w_ab*0.5*(w_ac+w_bc)*w_cd.
                        const double w_abc = w_ab * 0.5
                            * (w(cell_p, a, c) + w(cell_p, b, c));
                        // Zero prefix => w_J == 0 for every sigma shell (weights
                        // are documented above as lying in [0,1]); skip the s4 loop.
                        if (w_abc == 0.0) continue;

                        for (std::size_t s4 = 0; s4 < shells_q.size(); ++s4) {
                            const auto bf4 = shell2bf[s4]; const auto n4 = shells_q[s4].size();
                            const int d = atom_of_shell[s4];

                            // eq-18 four-center weight, bra at home (matches the
                            // full bra_home_full branch above).
                            const double w_J = w_abc * w(gd, c, d);
                            if (w_J == 0.0) continue;
                            // Per-quartet Schwarz bound.
                            if (do_screen &&
                                w_J * q_bra * Q[gd_idx][s3*nshells + s4] * D_max < schwarz_thr)
                                continue;
                            engine.compute(shells_ref[s1], shells_ref[s2],
                                           shells_p[s3], shells_q[s4]);
                            // A null buf[0] means the engine produced no block for
                            // this quartet; nothing is accumulated in that case.
                            if (const double* blk = buf[0])
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double t = blk[((i*n2+j)*n3+k)*n4+l] * w_J;
                                    if (t == 0.0) continue;
                                    const int mu = bf1 + i, nu = bf2 + j;
                                    const int la = bf3 + k, si = bf4 + l;
                                    // V half and V^T half of V_sym, folded into J ...
                                    Jloc(mu, nu) += 0.5 * t * P_gamma(la, si);
                                    Jloc(la, si) += 0.5 * t * P_gamma(mu, nu);
                                    // ... and into K.
                                    Kloc(mu, si) += 0.5 * t * P_gamma(la, nu);
                                    Kloc(la, nu) += 0.5 * t * P_gamma(mu, si);
                                }
                        }
                    }
                }
            }
        }

        // Reduce the per-thread partial sums.
        Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
        for (const auto& m : J_tls) Jm += m;
        for (const auto& m : K_tls) Km += m;
        // Same final Hermitisation as the full bra_home_full branch. Separate
        // destinations: A = 0.5*(A + A.transpose()) aliases in Eigen.
        const Eigen::MatrixXd Jh = 0.5 * (Jm + Jm.transpose());
        const Eigen::MatrixXd Kh = 0.5 * (Km + Km.transpose());
        return JKMatrices{Jh, Kh};
    }

    if (method == "aiccm2026dev-a" || method == "aiccmdev") {  // "aiccmdev" = deprecated alias
        // ================================================================
        // aiccm2026dev-a: the symmetric Born-von Karman-torus four-center
        // (AICCM_ALGORITHM.md §13; Python padded.ccm_eri_symmetric). Two
        // changes vs bra_home_full (eq 18), both needed for exact 8-fold
        // permutation symmetry on ANY lattice:
        //   * symmetric bridge  1/4(w_ac + w_bc + w_ad + w_bd)  -- treats the
        //     ket functions rho, sigma identically (eq 18 uses 1/2(w_ac+w_bc),
        //     singling out the ket anchor rho);
        //   * independent minimum-image fold -- rho at g_c and sigma at g_e are
        //     EACH the min image of the home bra; the ket-pair weight w_rho_sigma
        //     is taken at g_e - g_c (eq 18 chains sigma to rho at g_c+g_d).
        // V is then bra-ket symmetrised and contracted exactly as bra_home_full.
        // O(nbf^4) memory -- small-system reference (matches the padded route).
        //
        // Returns JKMatrices{J, K}, each passed through 0.5*(M + M^T).
        // ================================================================

        // One (nbf^2 x nbf^2) partial tensor per thread, row index mu*nbf+nu
        // (bra pair) and column index la*nbf+si (ket pair).
        const int nbf2 = nbf * nbf;
        std::vector<Eigen::MatrixXd> Vj_tls(n_threads, Eigen::MatrixXd::Zero(nbf2, nbf2));

        // Flatten the (rho cell, sigma cell) pairs drawn from weight_cells into
        // one index for the OpenMP loop.
        const std::vector<Eigen::Vector3i> wcells(weight_cells.begin(), weight_cells.end());
        const int n_w = static_cast<int>(wcells.size());
        const int n_pairs = n_w * n_w;

        #pragma omp parallel for schedule(dynamic)
        for (int idx = 0; idx < n_pairs; ++idx) {
            const Eigen::Vector3i gc = wcells[idx / n_w];          // rho cell
            const Eigen::Vector3i ge = wcells[idx % n_w];          // sigma cell
            const Eigen::Vector3i grel(ge[0]-gc[0], ge[1]-gc[1], ge[2]-gc[2]);
            // rho,sigma must be a minimum-image pair (w_rho_sigma at g_e-g_c).
            if (wlookup.find(grel) == wlookup.end()) continue;

            // Both ket cells need an entry in cell_pos_map (index into shells_at);
            // pairs without one are skipped.
            int gc_idx = -1, ge_idx = -1;
            auto it_c = cell_pos_map.find(gc); if (it_c != cell_pos_map.end()) gc_idx = it_c->second;
            auto it_e = cell_pos_map.find(ge); if (it_e != cell_pos_map.end()) ge_idx = it_e->second;
            if (gc_idx < 0 || ge_idx < 0) continue;

            // Per-thread engine and accumulator.
            const auto tid = static_cast<std::size_t>(omp_thread_index());
            auto& engine = engines[tid];
            const auto& buf = engine.results();
            auto& Vj = Vj_tls[tid];
            const auto& shells_c = shells_at[gc_idx];
            const auto& shells_e = shells_at[ge_idx];

            // Bra (mu_0 nu_0) at home: both bra indices use shells_ref (see the
            // bra_home_full note on preserving the mu<->nu symmetry).
            for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
                const int a = atom_of_shell[s1];
                for (std::size_t s2 = 0; s2 < shells_ref.size(); ++s2) {
                    const auto bf2 = shell2bf[s2]; const auto n2 = shells_ref[s2].size();
                    const int b = atom_of_shell[s2];
                    // Bra-pair weight at the home cell; a zero kills every quartet
                    // built on this bra pair.
                    const double w_ab = w(home, a, b);
                    if (w_ab == 0.0) continue;
                    for (std::size_t s3 = 0; s3 < shells_c.size(); ++s3) {
                        const auto bf3 = shell2bf[s3]; const auto n3 = shells_c[s3].size();
                        const int c = atom_of_shell[s3];
                        for (std::size_t s4 = 0; s4 < shells_e.size(); ++s4) {
                            const auto bf4 = shell2bf[s4]; const auto n4 = shells_e[s4].size();
                            const int d = atom_of_shell[s4];

                            // Ket-pair weight at the relative cell g_e - g_c.
                            const double w_cd = w(grel, c, d);
                            if (w_cd == 0.0) continue;
                            // Symmetric bridge: average of the four bra-atom /
                            // ket-atom weights (rho side at g_c, sigma side at g_e).
                            const double bridge = 0.25 * (w(gc, a, c) + w(gc, b, c)
                                                          + w(ge, a, d) + w(ge, b, d));
                            const double w4 = w_ab * bridge * w_cd;
                            if (w4 == 0.0) continue;

                            // (mu_0 nu_0 | lambda_{g_c} sigma_{g_e})
                            engine.compute(shells_ref[s1], shells_ref[s2],
                                           shells_c[s3], shells_e[s4]);
                            // A null buf[0] means the engine produced no block for
                            // this quartet; nothing is accumulated in that case.
                            // The block is indexed as [i][j][k][l] with l fastest.
                            if (const double* blk = buf[0])
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double v = blk[((i*n2+j)*n3+k)*n4+l];
                                    const int mu = bf1 + i, nu = bf2 + j;
                                    const int la = bf3 + k, si = bf4 + l;
                                    Vj(mu*nbf+nu, la*nbf+si) += v * w4;
                                }
                        }
                    }
                }
            }
        }

        // Reduce the per-thread partial tensors.
        Eigen::MatrixXd V_full = Eigen::MatrixXd::Zero(nbf2, nbf2);
        for (const auto& m : Vj_tls) V_full += m;
        // Bra-ket symmetrise (the remaining generator; mu<->nu and rho<->sigma
        // already hold by construction). Separate destination -- aliasing note
        // in bra_home_full applies.
        const Eigen::MatrixXd V_sym = 0.5 * (V_full + V_full.transpose());

        // Contract J and K from the same symmetrised tensor:
        //   J[mu,nu] = sum_{la,si} P[la,si] * V_sym[mu*n+nu, la*n+si]
        //   K[mu,nu] = sum_{la,si} P[la,si] * V_sym[mu*n+si, la*n+nu]
        // Exactly-zero density elements are skipped.
        Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
        for (int mu = 0; mu < nbf; ++mu) {
            for (int nu = 0; nu < nbf; ++nu) {
                double jsum = 0.0, ksum = 0.0;
                for (int la = 0; la < nbf; ++la) {
                    for (int si = 0; si < nbf; ++si) {
                        const double p = P_gamma(la, si);
                        if (p != 0.0) {
                            jsum += p * V_sym(mu*nbf+nu, la*nbf+si);
                            ksum += p * V_sym(mu*nbf+si, la*nbf+nu);
                        }
                    }
                }
                Jm(mu, nu) = jsum;
                Km(mu, nu) = ksum;
            }
        }
        // The symmetric tensor already gives Hermitian J/K; this is a no-op
        // safeguard (separate destinations -- Eigen aliasing).
        const Eigen::MatrixXd Jh = 0.5 * (Jm + Jm.transpose());
        const Eigen::MatrixXd Kh = 0.5 * (Km + Km.transpose());
        return JKMatrices{Jh, Kh};
    }

    if (method == "aiccm2026dev-a-direct" || method == "aiccmdev-direct") {
        // ================================================================
        // aiccm2026dev-a-direct: the INTEGRAL-DIRECT form of "aiccm2026dev-a"
        // (Phase 3b). Identical quartet loop, weights, and bra-ket
        // symmetrisation as the full-tensor branch above, but each weighted
        // quartet block is folded straight into J and K instead of accumulated
        // into the O(nbf^4) effective tensor V. Peak memory drops from
        // (n_threads+2)*nbf^4 (Vj_tls + V_full + V_sym) to n_threads*2*nbf^2
        // (thread-local J/K) -- this is what makes 3-D cells at production basis
        // fit in RAM. The dense "aiccm2026dev-a" branch is PRESERVED above as the
        // small-cluster comparison reference (and the byte-for-byte gate target).
        //
        // Exactness of the fold: with V_sym = 0.5*(V + V^T) and the full-branch
        // contractions  J[mn] = sum_{ls} P[ls] V_sym[mn,ls],
        //                K[mn] = sum_{ls} P[ls] V_sym[ms,ln],
        // a single block  t = (mu nu | la si) * w4  contributes
        //   J[mu,nu] += 0.5 t P[la,si]      J[la,si] += 0.5 t P[mu,nu]
        //   K[mu,si] += 0.5 t P[la,nu]      K[la,nu] += 0.5 t P[mu,si]
        // (the paired terms are the V and V^T halves of V_sym; the home-bra /
        // imaged-ket loop never visits the transpose itself). A reduction +
        // final Hermitisation match the full branch; result agrees to ~1e-12
        // (summation reorder), NOT bit-for-bit -- hence opt-in, not a silent
        // replacement of the dense reference.
        //
        // Returns JKMatrices{J, K}, each Hermitised as 0.5*(M + M^T).
        // ================================================================

        // One J and one K accumulator per thread; summed after the loop.
        std::vector<Eigen::MatrixXd> J_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
        std::vector<Eigen::MatrixXd> K_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

        // Flatten the (rho cell, sigma cell) pairs drawn from weight_cells.
        const std::vector<Eigen::Vector3i> wcells(weight_cells.begin(), weight_cells.end());
        const int n_w = static_cast<int>(wcells.size());
        const int n_pairs = n_w * n_w;

        #pragma omp parallel for schedule(dynamic)
        for (int idx = 0; idx < n_pairs; ++idx) {
            const Eigen::Vector3i gc = wcells[idx / n_w];          // rho cell
            const Eigen::Vector3i ge = wcells[idx % n_w];          // sigma cell
            const Eigen::Vector3i grel(ge[0]-gc[0], ge[1]-gc[1], ge[2]-gc[2]);
            // rho,sigma must be a minimum-image pair (w_rho_sigma at g_e-g_c).
            if (wlookup.find(grel) == wlookup.end()) continue;

            // Both ket cells need an entry in cell_pos_map (index into shells_at).
            int gc_idx = -1, ge_idx = -1;
            auto it_c = cell_pos_map.find(gc); if (it_c != cell_pos_map.end()) gc_idx = it_c->second;
            auto it_e = cell_pos_map.find(ge); if (it_e != cell_pos_map.end()) ge_idx = it_e->second;
            if (gc_idx < 0 || ge_idx < 0) continue;

            const auto tid = static_cast<std::size_t>(omp_thread_index());
            auto& engine = engines[tid];
            const auto& buf = engine.results();
            auto& Jloc = J_tls[tid];
            auto& Kloc = K_tls[tid];
            const auto& shells_c = shells_at[gc_idx];
            const auto& shells_e = shells_at[ge_idx];

            // Opt-in Schwarz screening (active only when screen, i.e.
            // opts.schwarz_threshold > 0; off by default keeps the kernel exact).
            // The ket pair (lambda_{g_c}, sigma_{g_e}) Schwarz factor is
            // translation-invariant -> Q at separation grel = g_e - g_c; the bra
            // is home-home -> Q at c_zero_idx. WSSC weights are in [0,1] so
            // |w4| <= 1, making w4 * Q_bra * Q_ket * D_max a valid upper bound on
            // the |w4 (mu nu|la si) P[la,si]| J/K contribution.
            int grel_idx = -1;
            if (screen) { auto itg = cell_index_map.find(grel); if (itg != cell_index_map.end()) grel_idx = itg->second; }
            const bool do_screen = screen && grel_idx >= 0 && c_zero_idx >= 0;

            // Bra (mu_0 nu_0) at home: both bra indices use shells_ref (see the
            // bra_home_full note on preserving the mu<->nu symmetry).
            for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
                const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
                const int a = atom_of_shell[s1];
                for (std::size_t s2 = 0; s2 < shells_ref.size(); ++s2) {
                    const auto bf2 = shell2bf[s2]; const auto n2 = shells_ref[s2].size();
                    const int b = atom_of_shell[s2];
                    const double w_ab = w(home, a, b);
                    if (w_ab == 0.0) continue;
                    const double q_bra = do_screen ? Q[c_zero_idx][s1*nshells + s2] : 0.0;
                    // Bra-level early-out: even the largest ket pair can't survive.
                    if (do_screen && q_bra * Q_max[grel_idx] * D_max < schwarz_thr) continue;
                    for (std::size_t s3 = 0; s3 < shells_c.size(); ++s3) {
                        const auto bf3 = shell2bf[s3]; const auto n3 = shells_c[s3].size();
                        const int c = atom_of_shell[s3];

                        // rho-side half of the symmetric bridge. It depends only on
                        // (a, b, c), so it is looked up once per rho shell rather
                        // than once per quartet. Being the leading two terms of the
                        // four-term sum, adding the sigma-side terms to it below
                        // keeps the original left-to-right summation order.
                        const double bridge_rho = w(gc, a, c) + w(gc, b, c);

                        for (std::size_t s4 = 0; s4 < shells_e.size(); ++s4) {
                            const auto bf4 = shell2bf[s4]; const auto n4 = shells_e[s4].size();
                            const int d = atom_of_shell[s4];

                            // Ket-pair weight at the relative cell g_e - g_c.
                            const double w_cd = w(grel, c, d);
                            if (w_cd == 0.0) continue;
                            // Symmetric bridge 1/4(w_ac + w_bc + w_ad + w_bd).
                            const double bridge = 0.25 * (bridge_rho
                                                          + w(ge, a, d) + w(ge, b, d));
                            const double w4 = w_ab * bridge * w_cd;
                            if (w4 == 0.0) continue;
                            // Per-quartet Schwarz bound.
                            if (do_screen &&
                                w4 * q_bra * Q[grel_idx][s3*nshells + s4] * D_max < schwarz_thr)
                                continue;

                            // (mu_0 nu_0 | lambda_{g_c} sigma_{g_e})
                            engine.compute(shells_ref[s1], shells_ref[s2],
                                           shells_c[s3], shells_e[s4]);
                            // A null buf[0] means the engine produced no block for
                            // this quartet; nothing is accumulated in that case.
                            if (const double* blk = buf[0])
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double t = blk[((i*n2+j)*n3+k)*n4+l] * w4;
                                    if (t == 0.0) continue;
                                    const int mu = bf1 + i, nu = bf2 + j;
                                    const int la = bf3 + k, si = bf4 + l;
                                    // V and V^T halves of V_sym, folded into J:
                                    Jloc(mu, nu) += 0.5 * t * P_gamma(la, si);
                                    Jloc(la, si) += 0.5 * t * P_gamma(mu, nu);
                                    // ...and into K:
                                    Kloc(mu, si) += 0.5 * t * P_gamma(la, nu);
                                    Kloc(la, nu) += 0.5 * t * P_gamma(mu, si);
                                }
                        }
                    }
                }
            }
        }

        // Reduce the per-thread partial sums.
        Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
        Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
        for (const auto& m : J_tls) Jm += m;
        for (const auto& m : K_tls) Km += m;
        // Final Hermitisation matches the full branch. Separate destinations:
        // A = 0.5*(A + A.transpose()) aliases in Eigen.
        const Eigen::MatrixXd Jh = 0.5 * (Jm + Jm.transpose());
        const Eigen::MatrixXd Kh = 0.5 * (Km + Km.transpose());
        return JKMatrices{Jh, Kh};
    }


    // ================================================================
    // union12: nu imaged, K symmetrised (AICCM reference frame)
    // ================================================================
    const int n_pairs = n_c * n_c;
    std::vector<Eigen::MatrixXd> J_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));
    std::vector<Eigen::MatrixXd> K_tls(n_threads, Eigen::MatrixXd::Zero(nbf, nbf));

    #pragma omp parallel for schedule(dynamic)
    for (int idx = 0; idx < n_pairs; ++idx) {
        const int c_g = idx / n_c;
        const int c_p = idx % n_c;
        const auto tid = static_cast<std::size_t>(omp_thread_index());
        auto& engine = engines[tid];
        const auto& buf = engine.results();
        auto& Jloc = J_tls[tid];
        auto& Kloc = K_tls[tid];
        const auto& shells_g = shells_at[c_g];
        const auto& shells_p = shells_at[c_p];
        const Eigen::Vector3i cell_g = cells[c_g].index;
        const Eigen::Vector3i cell_p = cells[c_p].index;
        const Eigen::Vector3i pg(cell_p[0]-cell_g[0], cell_p[1]-cell_g[1], cell_p[2]-cell_g[2]);
        const Eigen::Vector3i gp(cell_g[0]-cell_p[0], cell_g[1]-cell_p[1], cell_g[2]-cell_p[2]);

        int c_pg_idx = -1;
        if (screen) {
            auto it = cell_index_map.find(cells[c_p].index - cells[c_g].index);
            if (it != cell_index_map.end()) c_pg_idx = it->second;
        }
        if (screen) {
            const bool jp = (Q_max[c_g] * ((c_zero_idx>=0)?Q_max[c_zero_idx]:0.0) * D_max >= schwarz_thr);
            const double qpg = (c_pg_idx>=0) ? Q_max[c_pg_idx] : 0.0;
            if (!jp && !(Q_max[c_p]*qpg*D_max >= schwarz_thr)) continue;
        }

        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1]; const auto n1 = shells_ref[s1].size();
            const int a = atom_of_shell[s1];
            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                const auto bf2 = shell2bf[s2]; const auto n2 = shells_g[s2].size();
                const int b = atom_of_shell[s2];
                for (std::size_t s3 = 0; s3 < shells_p.size(); ++s3) {
                    const auto bf3 = shell2bf[s3]; const auto n3 = shells_p[s3].size();
                    const int c = atom_of_shell[s3];
                    for (std::size_t s4 = 0; s4 < shells_p.size(); ++s4) {
                        const auto bf4 = shell2bf[s4]; const auto n4 = shells_p[s4].size();
                        const int d = atom_of_shell[s4];

                        const double w_ab = w(cell_g, a, b);
                        if (w_ab == 0.0) continue;
                        const double w_ac = w(cell_p, a, c);
                        const double w_bc = w(pg, b, c);
                        const double w_cd = w(home, c, d);
                        const double w_J = w_ab * 0.5*(w_ac+w_bc) * w_cd;

                        if (w_J != 0.0) {
                            engine.compute(shells_ref[s1], shells_g[s2], shells_p[s3], shells_p[s4]);
                            if (const double* blk = buf[0])
                                for (std::size_t i = 0; i < n1; ++i)
                                for (std::size_t j = 0; j < n2; ++j)
                                for (std::size_t k = 0; k < n3; ++k)
                                for (std::size_t l = 0; l < n4; ++l) {
                                    const double v = blk[((i*n2+j)*n3+k)*n4+l];
                                    Jloc(bf1+i, bf2+j) += P_gamma(bf3+k, bf4+l) * v * w_J;
                                }
                        }

                        // K1: V[mu,lambda,nu,sigma] -> (mu_0 lambda_p | nu_g sigma_p)
                        const double wk1_ab = w(cell_p, a, c);
                        if (wk1_ab != 0.0) {
                            const double wk1_ac = w(cell_g, a, b);
                            const double wk1_bc = w(gp, c, b);
                            const double wk1_cd = w(home, b, d);
                            const double w_K1 = wk1_ab * 0.5*(wk1_ac+wk1_bc) * wk1_cd;
                            if (w_K1 != 0.0) {
                                engine.compute(shells_ref[s1], shells_p[s3], shells_g[s2], shells_p[s4]);
                                if (const double* blk = buf[0])
                                    for (std::size_t i = 0; i < n1; ++i)
                                    for (std::size_t j = 0; j < n2; ++j)
                                    for (std::size_t k = 0; k < n3; ++k)
                                    for (std::size_t l = 0; l < n4; ++l) {
                                        const double v = blk[((i*n3+k)*n2+j)*n4+l];
                                        Kloc(bf1+i, bf2+j) += P_gamma(bf3+k, bf4+l) * v * w_K1;
                                    }
                            }
                        }
                        // K2: V[mu,sigma,nu,lambda] -> (mu_0 sigma_p | nu_g lambda_p)
                        const double wk2_ab = w(cell_p, a, d);
                        if (wk2_ab != 0.0) {
                            const double wk2_ac = w(cell_g, a, b);
                            const double wk2_bc = w(gp, d, b);
                            const double wk2_cd = w(home, b, c);
                            const double w_K2 = wk2_ab * 0.5*(wk2_ac+wk2_bc) * wk2_cd;
                            if (w_K2 != 0.0) {
                                engine.compute(shells_ref[s1], shells_p[s4], shells_g[s2], shells_p[s3]);
                                if (const double* blk = buf[0])
                                    for (std::size_t i = 0; i < n1; ++i)
                                    for (std::size_t j = 0; j < n2; ++j)
                                    for (std::size_t k = 0; k < n4; ++k)
                                    for (std::size_t l = 0; l < n3; ++l) {
                                        const double v = blk[((i*n4+k)*n2+j)*n3+l];
                                        Kloc(bf1+i, bf2+j) += P_gamma(bf4+k, bf3+l) * v * w_K2;
                                    }
                            }
                        }
                    }
                }
            }
        }
    }

    Eigen::MatrixXd Jm = Eigen::MatrixXd::Zero(nbf, nbf);
    Eigen::MatrixXd Km = Eigen::MatrixXd::Zero(nbf, nbf);
    for (const auto& m : J_tls) Jm += m;
    for (const auto& m : K_tls) Km += m;
    Km *= 0.5;
    return JKMatrices{Jm, Km};
}
}  // namespace vibeqc
