"""User-selectable J/K (Coulomb + exchange) method for periodic SCF,
plus an intelligent AUTO picker.

vibe-qc supports several alternatives to FFT-Poisson for the periodic
two-electron Hartree-Fock / KS build. Each has a different sweet spot
in (basis quality, lattice shape, system size, accuracy target). This
module provides:

* :class:`PeriodicJKMethod` — enum naming each method.
* :func:`pick_jk_method` — heuristic that turns ``AUTO`` into a concrete
  choice given the system + basis + cell shape.
* :func:`validate_jk_method` — hard error / warning on illegal or
  risky combinations.
* :func:`describe_jk_method` — short human-readable label for logs.

Status by method:

  Method                  Status
  ---------------------   ---------------------------------------------
  GDF                     Γ-only RHF/RKS/hybrid native implementation
                            available. Multi-k GDF is pending. The
                            earlier PySCF-backed spike is retired:
                            PySCF and CRYSTAL are external reference
                            programs, not in-process vibe-qc backends.
  BIPOLE                  ✓ multi-k RHF/UHF/RKS/UKS via CRYSTAL-gauge
                            Ewald J-split: shared Ewald α, analytic
                            J^LR, direct J^SR + K (HF) or V_xc (DFT),
                            optional multipole far-field acceleration.
  DIRECT                  Limited scope. Plumbed against the existing
                            C++ ``build_fock_2e_real_space`` which does
                            the proper triple-cell-sum + Schwarz
                            screening on each pair displacement. BUT —
                            with omega=0 (full Coulomb), the truncated
                            real-space sum DIVERGES for tight ionic
                            crystals (verified on MgO sto-3g: max|J−J_py|
                            ≈ 363 Ha at cutoff 18 bohr). DIRECT is only
                            valid for vacuum-padded (molecular-limit)
                            cells where the cutoff actually clips the
                            density. AUTO never resolves to DIRECT;
                            request it explicitly for vacuum-padded
                            cells. Native GDF/FFTDF is the production
                            target.
  FFT_POISSON             ✓ native Ewald-split FFT-Poisson path.
                            Supports skew cells at the FFT metric/grid
                            level; still legacy for production ionic
                            crystals until the native FFTDF/GDF parity
                            work replaces the old EWALD_3D stack.
                            Kept selectable for parity tests.
  GPW                     standalone only (v0.10.x M2). Gaussian +
                            plane-wave Hartree-J via FFT-Poisson on a
                            smooth real-space grid; pseudopotential-only.
                            Closed-shell RHF is available through the
                            standalone ``vibeqc.run_periodic_rhf_gpw``
                            entry. It is not yet plumbed through
                            ``run_periodic_job``, so selecting GPW here
                            raises. Target: sub-mHa Si bulk vs CP2K. See
                            docs/design_periodic_gapw.md.
  GAPW                    not yet implemented (v0.10.x M3 deliverable).
                            Lippert & Hutter Gaussian-augmented
                            plane-wave; adds per-atom radial
                            augmentation on top of GPW for
                            all-electron accuracy. Same design doc.
  RSGDF                   not yet implemented (range-separated GDF;
                            Ye & Berkelbach, J. Chem. Phys. 154,
                            131104 (2021), DOI 10.1063/5.0046617).
  CFMM                    not yet implemented (continuous fast-multipole;
                            White et al. CPL 230, 8).

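The AUTO picker currently keys only on the SCF flavour:

  - **Closed shell** (RHF/RKS, and anything not UHF/UKS): GDF.
  - **Open shell** (UHF/UKS): BIPOLE, because the native GDF driver
    does not yet support unrestricted references.

``pick_jk_method`` already accepts the lattice, basis name and atom
count so that future heuristics can use them. These are the intended
criteria:

  - **Lattice shape**: all implemented methods accept general 3D
    lattice vectors at the API level. Method quality is still basis-
    and system-dependent.
  - **Basis compactness**: a "compact" basis (sto-3g, minimal) lets
    DIRECT converge. Diffuse bases (cc-pVTZ, def2-TZVP) need GDF or
    RSGDF.
  - **System size**: huge supercells favor CFMM (when implemented).

The default policy does not silently fall back to a method that is known
to be scientifically unsafe for tight ionic crystals. AUTO never
resolves to DIRECT or to the legacy FFT_POISSON path; those must be
requested explicitly.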
"""

from __future__ import annotations

import enum
import warnings
from typing import Optional

import numpy as np

# Public API of this module. The underscore-prefixed tables and helpers
# (``_IMPLEMENTED``, ``_DESCRIPTIONS``, ``_basis_is_compact``) are internal.
__all__ = [
    "PeriodicJKMethod",
    "pick_jk_method",
    "validate_jk_method",
    "describe_jk_method",
    "is_orthorhombic",
]


class PeriodicJKMethod(enum.Enum):
    """User-selectable periodic J/K builder.

    Use ``AUTO`` to let vibe-qc choose (resolved by :func:`pick_jk_method`);
    pass any other value to force a specific method. The resolved method
    is logged in the output file's banner so reproducibility doesn't
    depend on the AUTO heuristic version.

    Member values are lowercase strings, so ``PeriodicJKMethod("gdf")``
    looks a member up by value. Member order is significant: it is the
    order in which :func:`pick_jk_method` lists valid choices when it
    rejects an unknown name.
    """

    AUTO = "auto"  # placeholder; resolved by pick_jk_method
    GDF = "gdf"  # Γ-only RHF/RKS native; multi-k pending
    BIPOLE = "bipole"  # CRYSTAL-gauge Ewald J-split, RHF/UHF/RKS/UKS;
    # multi-k; shared Ewald α; hybrids OK
    DIRECT = "direct"  # vacuum-padded cells only — see module docstring
    FFT_POISSON = "fft_poisson"  # implemented; native FFT metric
    # GPW / GAPW: v0.10.x acceleration route (Lippert & Hutter 1999;
    # docs/design_periodic_gapw.md). GPW = pseudopotential-only smooth
    # plane-wave J; GAPW = all-electron via Gaussian augmentation on
    # the smooth grid. Kept as separate values (not one flag) per
    # design doc § 9 decision 3 — distinct user-facing capability
    # claim, clearer dispatch. validate_jk_method rejects both with
    # NotImplementedError: GPW is only reachable through the standalone
    # `vibeqc.run_periodic_rhf_gpw` entry (not yet plumbed through
    # `run_periodic_job`), and GAPW (M3) is not implemented.
    GPW = "gpw"  # standalone-only; not dispatchable here yet (v0.10.x M2)
    GAPW = "gapw"  # not yet implemented (v0.10.x M3 deliverable)
    RSGDF = "rsgdf"  # not yet implemented
    CFMM = "cfmm"  # not yet implemented


# Methods with a native vibe-qc implementation that the job dispatcher
# accepts. Deliberately absent: GPW (M2, smooth-grid Hartree J) and
# GAPW (M3, adds Gaussian augmentation), plus RSGDF and CFMM, none of
# which are wired up yet.
_IMPLEMENTED: frozenset[PeriodicJKMethod] = frozenset(
    (
        PeriodicJKMethod.GDF,  # Γ-only RHF/RKS in periodic_runner
        PeriodicJKMethod.BIPOLE,  # multi-k RHF/UHF/RKS/UKS via pbc_bipole
        PeriodicJKMethod.DIRECT,  # only valid for vacuum-padded cells
        PeriodicJKMethod.FFT_POISSON,
    )
)


# ============================================================
# Helpers
# ============================================================


def is_orthorhombic(lattice: np.ndarray, tol: float = 1e-10) -> bool:
    """True iff ``lattice`` is diagonal (axis-aligned orthorhombic).

    Compatibility helper for older routing code and docs. ``lattice``
    is the (3, 3) Cartesian matrix whose columns are a₁, a₂, a₃ (the
    test is transpose-invariant, so a row convention works too).

    The comparison is relative: every off-diagonal entry must satisfy
    ``|L_ij| <= tol * max(||diag(L)||, 1)``. The floor of 1 keeps the
    threshold from collapsing to zero for tiny or zero diagonals.
    ``tol=0`` demands exactly-zero off-diagonal entries.
    """
    cell = np.asarray(lattice, dtype=float)
    diagonal = np.diag(cell)
    off_diag = cell - np.diag(diagonal)
    scale = max(float(np.linalg.norm(diagonal)), 1.0)
    # `<=` (not `<`) so that tol=0 accepts an exactly diagonal matrix.
    return float(np.max(np.abs(off_diag))) <= tol * scale


def _basis_is_compact(basis_name: str) -> bool:
    """Is this a compact (minimal / sto-3g-like) basis where DIRECT
    can converge with manageable cell-cutoff?

    Heuristic: name starts with "sto" (any zeta count) or contains
    "minimal". Could be extended once DIRECT is back online.
    """
    n = (basis_name or "").lower()
    return n.startswith("sto") or "minimal" in n


# ============================================================
# AUTO picker
# ============================================================


def pick_jk_method(
    method: PeriodicJKMethod | str,
    *,
    lattice: np.ndarray,
    basis_name: str,
    n_atoms: int,
    scf_method: str = "RHF",
) -> PeriodicJKMethod:
    """Resolve an ``AUTO`` choice into a concrete :class:`PeriodicJKMethod`.

    ``method`` may be a :class:`PeriodicJKMethod` or its string value.
    Strings are matched case-insensitively, surrounding whitespace is
    ignored, and ``-`` is accepted in place of ``_`` (so ``"FFT-Poisson"``
    resolves to ``PeriodicJKMethod.FFT_POISSON``).

    A concrete (non-``AUTO``) method is returned unchanged. This function
    does NOT check that the method is implemented or compatible with the
    lattice / basis; call :func:`validate_jk_method` on the result for
    that.

    When ``AUTO``: returns BIPOLE for open-shell ``UHF`` / ``UKS`` (the
    GDF driver does not support unrestricted references yet) and GDF for
    everything else (closed-shell RHF/RKS). ``scf_method`` is compared
    case-insensitively with surrounding whitespace ignored.

    ``lattice``, ``basis_name`` and ``n_atoms`` are accepted for
    lattice-, basis- and size-aware heuristics; the current heuristic
    does not use them.

    Raises:
        ValueError: ``method`` is a string naming no known method.
    """
    if isinstance(method, str):
        key = method.strip().lower().replace("-", "_")
        try:
            method = PeriodicJKMethod(key)
        except ValueError:
            valid = ", ".join(m.value for m in PeriodicJKMethod)
            # `from None`: the enum's own "is not a valid ..." error adds
            # nothing beyond this message.
            raise ValueError(
                f"Unknown periodic JK method: {method!r}. Valid: {valid}"
            ) from None
    if method != PeriodicJKMethod.AUTO:
        return method

    # AUTO heuristic. Prefer GDF for closed-shell (RHF/RKS), BIPOLE
    # for open-shell (UHF/UKS) since GDF does not support UHF/UKS yet.
    if scf_method.strip().upper() in ("UHF", "UKS"):
        return PeriodicJKMethod.BIPOLE
    return PeriodicJKMethod.GDF


# ============================================================
# Validation
# ============================================================


def validate_jk_method(
    method: PeriodicJKMethod,
    *,
    lattice: np.ndarray,
    basis_name: str,
) -> None:
    """Raise / warn on illegal or risky combinations.

    ``method`` should already be concrete (normally the output of
    :func:`pick_jk_method`). ``lattice`` is only inspected for BIPOLE,
    which needs a rank-3 lattice. ``basis_name`` is accepted for future
    basis-dependent checks and is currently unused.

    Raises:
        NotImplementedError: for GPW / GAPW (with a pointer to the
            standalone GPW entry) and for any other method not in the
            natively implemented set, which includes an unresolved AUTO.
        ValueError: for BIPOLE on a lattice of rank < 3.

    Warns:
        UserWarning: for FFT_POISSON (legacy EWALD_3D path) and for
            DIRECT (only valid for vacuum-padded cells).
    """
    # GPW has a working standalone SCF entry (`vibeqc.run_periodic_rhf_gpw`,
    # M2-full) but is not yet plumbed through `run_periodic_job`'s dispatch
    # table. The surrounding plumbing (scf_trace serialisation, MO summary,
    # .system manifest) first needs `GpwScfResult` adapted to the same
    # shape as the GDF / BIPOLE result types. GAPW (M3) is not implemented
    # at all. Both hard-error here with a specific pointer rather than
    # the generic "not yet wired up natively" message.
    if method == PeriodicJKMethod.GPW:
        raise NotImplementedError(
            "PeriodicJKMethod.GPW is not yet plumbed through "
            "`run_periodic_job`. Use `vibeqc.run_periodic_rhf_gpw"
            "(system, basis, *, grid=None, cutoff_ha=300)` instead "
            "— the standalone GPW SCF entry. Closed-shell RHF "
            "only at M2-full; DFT / open-shell are M3 work. See "
            "`docs/design_periodic_gapw.md` § GPW."
        )
    if method == PeriodicJKMethod.GAPW:
        raise NotImplementedError(
            "PeriodicJKMethod.GAPW is not yet implemented — "
            "the M3 Gaussian augmentation on top of the M2 GPW "
            "smooth-grid J. For the smooth-grid-only GPW route "
            "(no augmentation), use "
            "`vibeqc.run_periodic_rhf_gpw(system, basis, ...)`. "
            "See `docs/design_periodic_gapw.md` § GAPW."
        )
    if method not in _IMPLEMENTED:
        # Sort so the message is identical across runs: iterating a set
        # of str follows hash order, which is randomised per process.
        available = (
            "{" + ", ".join(repr(v) for v in sorted(m.value for m in _IMPLEMENTED)) + "}"
        )
        raise NotImplementedError(
            f"Periodic JK method {method.value!r} is not yet wired up "
            f"natively. Currently available: {available}"
        )

    if method == PeriodicJKMethod.GDF:
        return

    if method == PeriodicJKMethod.BIPOLE:
        # The CRYSTAL-gauge Ewald split is 3D-only. Rank-deficient
        # (1D/2D) lattices are rejected outright rather than silently
        # rerouted; the user must pick GDF or DIRECT explicitly.
        cell = np.asarray(lattice, dtype=float)
        if np.linalg.matrix_rank(cell) < 3:
            raise ValueError(
                "PeriodicJKMethod.BIPOLE requires a 3D lattice. "
                "For 1D/2D systems use GDF or DIRECT."
            )
        return

    if method == PeriodicJKMethod.FFT_POISSON:
        # Soft warning — the legacy FFT-Poisson Ewald path is known
        # broken on ionic crystals (see docs/handover_periodic_scf_rewrite.md).
        # Allow the user to opt in for parity testing, but tell them.
        warnings.warn(
            "PeriodicJKMethod.FFT_POISSON dispatches to the legacy "
            "EWALD_3D code path. That path is **known broken** on "
            "ionic crystals (e.g., MgO RHF was off by +241 Ha in the "
            "v0.7.0 parity sweep). It is kept selectable for "
            "regression testing while native FFTDF/GDF is built.",
            stacklevel=2,
        )
        return

    if method == PeriodicJKMethod.DIRECT:
        # build_fock_2e_real_space at omega=0 is real-space-truncated,
        # NOT Ewald-split. The truncated 1/r lattice sum diverges for
        # tight ionic crystals. Verified on MgO sto-3g cutoff=18:
        # max|J−J_pyscf| ≈ 363 Ha. DIRECT is only valid for cells
        # padded enough that the Coulomb tail outside cutoff is small.
        warnings.warn(
            "PeriodicJKMethod.DIRECT (build_fock_2e_real_space, omega=0) "
            "is only valid for VACUUM-PADDED (molecular-limit) cells. "
            "On tight ionic crystals the truncated real-space Coulomb "
            "lattice sum diverges (e.g., MgO sto-3g shows a historical "
            "hundreds-of-Hartree J error at cutoff 18 bohr). Native "
            "GDF/FFTDF is the target for tight crystals; DIRECT is "
            "appropriate for vacuum-padded molecules in PBC "
            "boxes (Makov-Payne regime). Future Method 2 augmentation "
            "would add a reciprocal-space LR contribution to make it "
            "valid for tight crystals (= Ewald composition).",
            stacklevel=2,
        )


# ============================================================
# Description
# ============================================================

# One-line, human-readable labels for output logs, keyed by method;
# looked up by describe_jk_method.
_DESCRIPTIONS: dict[PeriodicJKMethod, str] = {
    PeriodicJKMethod.AUTO: "AUTO — pick at runtime",
    PeriodicJKMethod.GDF: "GDF — Gaussian density fitting (native Γ-only RHF/RKS/hybrids; "
    "PySCF/CRYSTAL are external references only)",
    PeriodicJKMethod.BIPOLE: "BIPOLE — CRYSTAL-gauge Ewald J-split (multi-k; "
    "shared Ewald α across V_ne/E_nn/J^LR; "
    "all methods RHF/UHF/RKS/UKS; hybrids OK)",
    # DIRECT's limitation is the truncated (non-Ewald) Coulomb lattice
    # sum, as described in the module docstring and validate_jk_method.
    PeriodicJKMethod.DIRECT: "DIRECT — 4-center periodic lattice sum + Schwarz screening "
    "(partial: vacuum-padded cells only; truncated Coulomb sum "
    "diverges on tight ionic crystals)",
    PeriodicJKMethod.FFT_POISSON: "FFT_POISSON — Ewald-split with FFT-Poisson long-range "
    "(native FFT metric; legacy EWALD_3D path; known broken on "
    "ionic crystals)",
    PeriodicJKMethod.GPW: "GPW — Gaussian + plane-wave J via FFT-Poisson on a smooth "
    "real-space grid (Hartree-only; pseudopotentials required for "
    "heavy atoms; v0.10.x M2 — not yet implemented)",
    PeriodicJKMethod.GAPW: "GAPW — Gaussian-augmented plane-wave J (Lippert-Hutter); "
    "smooth plane-wave grid + per-atom radial augmentation gives "
    "all-electron accuracy on top of GPW (v0.10.x M3 — not yet implemented)",
    PeriodicJKMethod.RSGDF: "RSGDF — range-separated GDF (not yet implemented)",
    PeriodicJKMethod.CFMM: "CFMM — continuous fast multipole (not yet implemented)",
}


def describe_jk_method(method: PeriodicJKMethod | str) -> str:
    """Short human-readable description for output logs.

    Accepts a :class:`PeriodicJKMethod` or, for consistency with
    :func:`pick_jk_method`, its string value. Strings are
    case-insensitive, whitespace-tolerant, and accept ``-`` for ``_``.
    Anything without a registered description falls back to
    ``str(method)``; an unrecognised string is returned as given.
    """
    if isinstance(method, str):
        try:
            method = PeriodicJKMethod(method.strip().lower().replace("-", "_"))
        except ValueError:
            # Unknown name: echo it back, as str(method) would.
            return method
    return _DESCRIPTIONS.get(method, str(method))
