#include "vibeqc/lattice_integrals.hpp"

#include "vibeqc/ao_eval.hpp"
#include "vibeqc/init.hpp"
#include "vibeqc/thread_pool.hpp"

#include <libint2/engine.h>

#include <array>
#include <cmath>
#include <stdexcept>
#include <utility>
#include <vector>

namespace vibeqc {

namespace {

// Return a copy of ``shells`` in which every shell origin has been moved
// by ``dr`` (bohr). Exponents and contraction coefficients are copied
// verbatim: a lattice translation only changes where a shell is centred.
std::vector<libint2::Shell> shift_shells(
    const libint2::BasisSet& shells, const Eigen::Vector3d& dr) {
    std::vector<libint2::Shell> translated;
    translated.reserve(shells.size());
    for (const libint2::Shell& src : shells) {
        libint2::Shell moved = src;
        for (int k = 0; k < 3; ++k) {
            moved.O[k] += dr[k];
        }
        translated.push_back(std::move(moved));
    }
    return translated;
}

// Core driver: for every lattice cell g in ``cells``, compute
//   ⟨ χ_μ(0) | Op | χ_ν(g) ⟩,
// where χ_ν(g) is AO ν with its shell origin translated by g (the cell's
// r_cart). Op is a libint 1-body operator (overlap, kinetic, or nuclear).
// For nuclear attraction the caller supplies the already-lattice-summed
// point-charge list via ``nuclei``; for every other operator ``nuclei``
// must be null.
//
// Returns a LatticeMatrixSet whose blocks[c] is the nbf×nbf matrix for
// cells[c] (same order as ``cells``).
// Throws std::invalid_argument when ``nuclei`` does not match ``op``.
// ``system`` is not read here (the cell list is explicit); it is kept so the
// signature mirrors the other lattice builders.
LatticeMatrixSet compute_1e_lattice_matrix_explicit(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const std::vector<LatticeCell>& cells,
    libint2::Operator op,
    const std::vector<std::pair<double, std::array<double, 3>>>* nuclei) {
    static_cast<void>(system);

    const bool is_nuclear = (op == libint2::Operator::nuclear);
    if (is_nuclear && nuclei == nullptr) {
        // Without set_params the nuclear engine has no caller-provided
        // charges, so the result would not be the intended attraction
        // matrix. Fail loudly instead of returning it.
        throw std::invalid_argument(
            "compute_1e_lattice_matrix_explicit: the nuclear operator "
            "requires a point-charge list");
    }
    if (!is_nuclear && nuclei != nullptr) {
        throw std::invalid_argument(
            "compute_1e_lattice_matrix_explicit: point charges were supplied "
            "for a non-nuclear operator");
    }

    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());

    libint2::Engine prototype(op, shells_ref.max_nprim(), shells_ref.max_l(), 0);
    if (is_nuclear) {
        using nuc_params =
            libint2::operator_traits<libint2::Operator::nuclear>::oper_params_type;
        prototype.set_params(nuc_params{*nuclei});
    }
    // Pool of engine copies, indexed by omp_thread_index() in the loop so
    // each thread works with its own engine.
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();

    LatticeMatrixSet set;
    set.nbf = nbf;
    set.cells = cells;
    set.blocks.assign(cells.size(), Eigen::MatrixXd::Zero(nbf, nbf));

    const int n_cells = static_cast<int>(cells.size());

    #pragma omp parallel for schedule(dynamic)
    for (int c = 0; c < n_cells; ++c) {
        auto& engine = engines[static_cast<std::size_t>(omp_thread_index())];
        const auto& buf = engine.results();

        // Ket shells live in cell g; bra shells stay in the reference cell.
        const Eigen::Vector3d& g = cells[c].r_cart;
        const auto shells_g = shift_shells(shells_ref, g);

        Eigen::MatrixXd block = Eigen::MatrixXd::Zero(nbf, nbf);
        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1];
            const auto n1 = shells_ref[s1].size();
            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                // Translation preserves shell order, so the reference
                // shell2bf map is valid for the shifted shells too.
                const auto bf2 = shell2bf[s2];
                const auto n2 = shells_g[s2].size();

                engine.compute(shells_ref[s1], shells_g[s2]);
                const double* tile = buf[0];
                // A null tile leaves these entries at their zero initial
                // value (libint may return null for a screened pair).
                if (!tile) continue;

                // Tile is laid out row-major as n1 × n2.
                for (std::size_t i = 0; i < n1; ++i) {
                    for (std::size_t j = 0; j < n2; ++j) {
                        block(bf1 + i, bf2 + j) = tile[i * n2 + j];
                    }
                }
            }
        }
        set.blocks[c] = std::move(block);
    }
    return set;
}

// Classic entry point: enumerate the direct-lattice cells within
// opts.cutoff_bohr and forward them to the explicit-cell driver.
LatticeMatrixSet compute_1e_lattice_matrix(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const LatticeSumOptions& opts,
    libint2::Operator op,
    const std::vector<std::pair<double, std::array<double, 3>>>* nuclei) {
    return compute_1e_lattice_matrix_explicit(
        basis,
        system,
        direct_lattice_cells(system, opts.cutoff_bohr),
        op,
        nuclei);
}

// Produce the point-charge list used for nuclear attraction. Every atom of
// the unit cell is replicated into each lattice image within
// opts.nuclear_cutoff_bohr. Entries are ordered image-major: all atoms of
// the first image, then all atoms of the next, and so on. Each entry is
// (charge = Z, cartesian position in bohr).
std::vector<std::pair<double, std::array<double, 3>>>
build_periodic_nuclear_charges(const PeriodicSystem& system,
                               const LatticeSumOptions& opts) {
    using PointCharge = std::pair<double, std::array<double, 3>>;

    const auto images =
        direct_lattice_cells(system, opts.nuclear_cutoff_bohr);
    std::vector<PointCharge> charges;
    charges.reserve(images.size() * system.unit_cell.size());

    for (const auto& image : images) {
        const auto& t = image.r_cart;
        for (const auto& atom : system.unit_cell) {
            const std::array<double, 3> pos = {
                atom.xyz[0] + t[0],
                atom.xyz[1] + t[1],
                atom.xyz[2] + t[2],
            };
            charges.emplace_back(static_cast<double>(atom.Z), pos);
        }
    }
    return charges;
}

}  // namespace

// Overlap matrices S(g) for every lattice cell within opts.cutoff_bohr.
// Overlap has no point charges, hence the null nuclei pointer.
LatticeMatrixSet compute_overlap_lattice(const BasisSet& basis,
                                         const PeriodicSystem& system,
                                         const LatticeSumOptions& opts) {
    constexpr auto kOperator = libint2::Operator::overlap;
    return compute_1e_lattice_matrix(basis, system, opts, kOperator,
                                     /*nuclei=*/nullptr);
}

// Kinetic-energy matrices T(g) for every lattice cell within
// opts.cutoff_bohr. No point charges are involved.
LatticeMatrixSet compute_kinetic_lattice(const BasisSet& basis,
                                         const PeriodicSystem& system,
                                         const LatticeSumOptions& opts) {
    constexpr auto kOperator = libint2::Operator::kinetic;
    return compute_1e_lattice_matrix(basis, system, opts, kOperator,
                                     /*nuclei=*/nullptr);
}

// Nuclear-attraction matrices V(g), truncated in direct space. AO pairs
// range over cells within opts.cutoff_bohr. The attracting nuclei are all
// atom images within opts.nuclear_cutoff_bohr.
LatticeMatrixSet compute_nuclear_lattice(const BasisSet& basis,
                                         const PeriodicSystem& system,
                                         const LatticeSumOptions& opts) {
    const std::vector<std::pair<double, std::array<double, 3>>> point_charges =
        build_periodic_nuclear_charges(system, opts);
    return compute_1e_lattice_matrix(basis, system, opts,
                                     libint2::Operator::nuclear,
                                     &point_charges);
}

// Short-range (erfc-attenuated) nuclear-attraction matrices for every
// lattice cell within opts.cutoff_bohr. Point charges are all atom images
// within opts.nuclear_cutoff_bohr. The kernel per charge C is
// erfc(ω · r_C) / r_C.
//
// Throws std::invalid_argument if ``omega`` is NaN, infinite or negative.
LatticeMatrixSet compute_nuclear_erfc_lattice(const BasisSet& basis,
                                              const PeriodicSystem& system,
                                              double omega,
                                              const LatticeSumOptions& opts) {
    // ω is forwarded straight into libint's kernel; a NaN/inf value would
    // silently poison every block, so reject it before doing any work.
    if (!std::isfinite(omega) || omega < 0.0) {
        throw std::invalid_argument(
            "compute_nuclear_erfc_lattice: omega must be finite and >= 0");
    }

    ensure_libint_initialized();

    const auto& shells_ref = basis.libint();
    const int nbf = static_cast<int>(basis.nbasis());
    const auto nuclei = build_periodic_nuclear_charges(system, opts);

    // libint's erfc_nuclear takes a tuple (ω, point-charge list). ω is the
    // erfc attenuation parameter: the kernel is erfc(ω · r_C) / r_C, which
    // reduces to 1/r_C at ω=0 and vanishes as ω → ∞.
    libint2::Engine prototype(libint2::Operator::erfc_nuclear,
                              shells_ref.max_nprim(),
                              shells_ref.max_l(), 0);
    using erfc_params =
        libint2::operator_traits<libint2::Operator::erfc_nuclear>::oper_params_type;
    prototype.set_params(erfc_params{omega, nuclei});
    // Pool of engine copies, indexed by omp_thread_index() in the loop.
    auto engines = make_engine_pool(prototype);
    const auto shell2bf = shells_ref.shell2bf();

    LatticeMatrixSet set;
    set.nbf = nbf;
    set.cells = direct_lattice_cells(system, opts.cutoff_bohr);
    set.blocks.assign(set.cells.size(), Eigen::MatrixXd::Zero(nbf, nbf));

    const int n_cells = static_cast<int>(set.cells.size());

    #pragma omp parallel for schedule(dynamic)
    for (int c = 0; c < n_cells; ++c) {
        auto& engine = engines[static_cast<std::size_t>(omp_thread_index())];
        const auto& buf = engine.results();

        const Eigen::Vector3d& g = set.cells[c].r_cart;
        const auto shells_g = shift_shells(shells_ref, g);

        Eigen::MatrixXd block = Eigen::MatrixXd::Zero(nbf, nbf);
        for (std::size_t s1 = 0; s1 < shells_ref.size(); ++s1) {
            const auto bf1 = shell2bf[s1];
            const auto n1 = shells_ref[s1].size();
            for (std::size_t s2 = 0; s2 < shells_g.size(); ++s2) {
                // Shifted shells keep the reference ordering, so shell2bf
                // applies to them unchanged.
                const auto bf2 = shell2bf[s2];
                const auto n2 = shells_g[s2].size();

                engine.compute(shells_ref[s1], shells_g[s2]);
                const double* tile = buf[0];
                // Null tile: entries stay at their zero initial value.
                if (!tile) continue;

                // Tile is laid out row-major as n1 × n2.
                for (std::size_t i = 0; i < n1; ++i) {
                    for (std::size_t j = 0; j < n2; ++j) {
                        block(bf1 + i, bf2 + j) = tile[i * n2 + j];
                    }
                }
            }
        }
        set.blocks[c] = std::move(block);
    }
    return set;
}

// ---------------------------------------------------------------------------
// Phase 12e-c: full Ewald nuclear-attraction lattice sum via grid integration
// ---------------------------------------------------------------------------

namespace {

// Rebuild ``ref`` on a copy of the unit cell whose atoms are translated by
// ``dr`` (bohr). The basis is looked up again by ref.name() for the shifted
// Molecule. The original charge and multiplicity are passed through
// unchanged.
// NOTE(review): only the basis *name* is carried over. Confirm that any
// other construction options of ``ref`` are reproduced by this constructor.
BasisSet shifted_basis_for_cell(const BasisSet& ref,
                                const PeriodicSystem& system,
                                const Eigen::Vector3d& dr) {
    std::vector<Atom> atoms_at_g;
    atoms_at_g.reserve(system.unit_cell.size());
    for (const auto& atom : system.unit_cell) {
        const double x = atom.xyz[0] + dr[0];
        const double y = atom.xyz[1] + dr[1];
        const double z = atom.xyz[2] + dr[2];
        atoms_at_g.push_back(Atom{atom.Z, {x, y, z}});
    }
    Molecule translated(std::move(atoms_at_g), system.charge,
                        system.multiplicity);
    return BasisSet(translated, ref.name());
}

}  // namespace

// Full Ewald nuclear-attraction lattice matrices V(g) for a 3D system.
// The result is the sum of two parts:
//   * an analytic erfc-screened short-range part
//     (compute_nuclear_erfc_lattice with ω = α);
//   * a long-range part integrated on ``grid`` from
//     ewald_nuclear_potential(..., include_short_range = false).
// The returned cell list is the one used by the short-range part
// (direct_lattice_cells(system, opts.cutoff_bohr)).
//
// Throws std::invalid_argument when system.dim != 3, or when α cannot be
// resolved to a finite positive value. Throws std::runtime_error if the AO
// count from evaluate_ao disagrees with the analytic part.
LatticeMatrixSet compute_nuclear_lattice_ewald(const BasisSet& basis,
                                               const PeriodicSystem& system,
                                               const Grid& grid,
                                               const LatticeSumOptions& opts,
                                               const EwaldOptions& ewald_opts) {
    if (system.dim != 3) {
        throw std::invalid_argument(
            "compute_nuclear_lattice_ewald: 3D Ewald requires dim == 3. "
            "Use compute_nuclear_lattice (DIRECT_TRUNCATED) for 1D / 2D.");
    }

    // Resolve α (same auto-rule the Ewald engine uses).
    const double alpha = (ewald_opts.alpha > 0.0)
        ? ewald_opts.alpha
        : std::sqrt(-std::log(ewald_opts.tolerance))
              / ewald_opts.real_cutoff_bohr;
    // The auto-rule only yields a usable α for tolerance in (0, 1) and a
    // positive real_cutoff_bohr. Anything else produces NaN / inf / 0, which
    // would otherwise flow silently into both halves of the split.
    if (!std::isfinite(alpha) || alpha <= 0.0) {
        throw std::invalid_argument(
            "compute_nuclear_lattice_ewald: could not resolve a finite, "
            "positive Ewald alpha; check EwaldOptions.alpha, tolerance and "
            "real_cutoff_bohr");
    }

    // -----------------------------------------------------------------
    //   V(g) = V_short(g) + V_long(g)
    //
    //   V_short(g): analytical erfc-screened nuclear-attraction via
    //               libint (compute_nuclear_erfc_lattice). Captures the
    //               sharp 1/r spike at each nucleus with full libint
    //               accuracy.
    //   V_long(g):  grid integral of the smooth, bounded long-range
    //               potential ~erf(α r)/r.
    // -----------------------------------------------------------------

    // Short-range component, analytical. compute_nuclear_erfc_lattice
    // already emits the correctly-signed electronic potential (−Z erfc/r
    // integrated against the AOs). Its blocks are accumulated in place
    // below, so V starts out as V_short.
    LatticeMatrixSet V =
        compute_nuclear_erfc_lattice(basis, system, alpha, opts);

    // Long-range component, by grid quadrature. Evaluate the *smooth*
    // long-range Ewald potential at every grid point; the short-range part
    // is omitted (handled above analytically).
    const Eigen::VectorXd v_long_r =
        ewald_nuclear_potential(system, grid.points, ewald_opts,
                                /*include_short_range=*/false);

    const Eigen::MatrixXd chi_ref = evaluate_ao(basis, grid.points);
    if (chi_ref.cols() != static_cast<Eigen::Index>(V.nbf)) {
        throw std::runtime_error(
            "compute_nuclear_lattice_ewald: AO count from evaluate_ao "
            "disagrees with the analytic short-range matrices");
    }
    const Eigen::VectorXd wv = grid.weights.array() * v_long_r.array();
    const Eigen::MatrixXd chi_wv = chi_ref.array().colwise() * wv.array();

    // Reuse the short-range part's own cell list. The long-range blocks are
    // then aligned with it by construction, instead of recomputing the list
    // and merely comparing sizes.
    const int n_cells = static_cast<int>(V.cells.size());

    // NOTE(review): a BasisSet is constructed per cell inside the parallel
    // region. Confirm BasisSet construction is thread-safe.
    #pragma omp parallel for schedule(dynamic)
    for (int c = 0; c < n_cells; ++c) {
        const BasisSet basis_g =
            shifted_basis_for_cell(basis, system, V.cells[c].r_cart);
        const Eigen::MatrixXd chi_g = evaluate_ao(basis_g, grid.points);
        //   V_long_μν(g) = Σ_r χ_μ(r) · w(r) · v_long(r) · χ_ν(r − g)
        V.blocks[c] += chi_wv.transpose() * chi_g;
    }

    return V;
}

// ---- Public explicit-cell wrappers (Phase SYM3b, vibeqc scope) -----------

// Overlap matrices S(g) for a caller-chosen list of lattice cells.
LatticeMatrixSet compute_overlap_lattice_explicit(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const std::vector<LatticeCell>& cells) {
    constexpr auto kOperator = libint2::Operator::overlap;
    return compute_1e_lattice_matrix_explicit(basis, system, cells, kOperator,
                                              /*nuclei=*/nullptr);
}

// Kinetic-energy matrices T(g) for a caller-chosen list of lattice cells.
LatticeMatrixSet compute_kinetic_lattice_explicit(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const std::vector<LatticeCell>& cells) {
    constexpr auto kOperator = libint2::Operator::kinetic;
    return compute_1e_lattice_matrix_explicit(basis, system, cells, kOperator,
                                              /*nuclei=*/nullptr);
}

// Nuclear-attraction matrices V(g) for a caller-chosen list of lattice
// cells. ``cells`` fixes the set of AO translations. ``opts`` is used only
// to build the point-charge images (via opts.nuclear_cutoff_bohr).
LatticeMatrixSet compute_nuclear_lattice_explicit(
    const BasisSet& basis,
    const PeriodicSystem& system,
    const LatticeSumOptions& opts,
    const std::vector<LatticeCell>& cells) {
    const std::vector<std::pair<double, std::array<double, 3>>> point_charges =
        build_periodic_nuclear_charges(system, opts);
    constexpr auto kOperator = libint2::Operator::nuclear;
    return compute_1e_lattice_matrix_explicit(basis, system, cells, kOperator,
                                              &point_charges);
}

}  // namespace vibeqc
