"""Gaussian-cube volumetric data writer.

The Gaussian cube format is the de-facto interchange format for
molecular volumetric data — densities, molecular orbitals,
electrostatic potentials. Most viewers (VMD, Avogadro, PyMOL,
ChimeraX, …) read it directly.

Layout (the "Gaussian 98" convention):

    line 1  : title
    line 2  : second comment / property tag
    line 3  : N_atoms  x_origin  y_origin  z_origin   [N_val]
    line 4  : N_x      vx_x      vx_y      vx_z       (voxel along x)
    line 5  : N_y      vy_x      vy_y      vy_z
    line 6  : N_z      vz_x      vz_y      vz_z
    line 7+ : Z  charge  x  y  z   for each atom
    then    : data, scanned x_outer, z_inner; 6 floats / line

A *positive* N_atoms means scalar volumetric data and ``N_val`` is
omitted; a *negative* N_atoms signals "multiple values per voxel"
(used here for stacked MOs).

All coordinates and voxel vectors are in **bohr** — that's the cube
spec.

Public API
----------

* :func:`write_cube_density` — total electron density on a uniform
  grid (right-hand rule axis-aligned bounding box around the molecule).
* :func:`write_cube_mo` — one Kohn–Sham / Hartree–Fock MO.
* :func:`write_cube_mos` — a stack of MOs in a single multi-value cube
  (consumed by VMD / Avogadro as a multi-frame volume).

A grid helper :func:`make_uniform_grid` wraps the typical "wrap a box
around the molecule with N_x × N_y × N_z voxels and ``padding`` bohr
of breathing room" pattern. Adjust its ``spacing`` / ``padding`` for a
finer / coarser grid, or build a :class:`CubeGrid` directly (custom
``origin`` / ``spacing`` / ``shape``) if you need an off-center one.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Sequence, Tuple, Union

import numpy as np

from ._vibeqc_core import BasisSet, Molecule, evaluate_ao


__all__ = [
    "CubeGrid",
    "make_uniform_grid",
    "write_cube_density",
    "write_cube_mo",
    "write_cube_mos",
]


# ---------------------------------------------------------------------------
# Grid container + builder
# ---------------------------------------------------------------------------

@dataclass
class CubeGrid:
    """Axis-aligned, evenly spaced sampling grid (lengths in bohr).

    Grid point ``(i, j, k)`` is located at
    ``origin + (i * spacing[0], j * spacing[1], k * spacing[2])``, so
    ``origin`` is the position of point ``(0, 0, 0)`` and ``spacing``
    holds the step along x, y and z. The cube format also allows
    non-orthogonal voxel vectors; this container does not model them.
    """

    origin: np.ndarray              # (3,)  position of point (0, 0, 0), bohr
    spacing: np.ndarray             # (3,)  step along x, y, z, bohr (expected positive)
    shape: Tuple[int, int, int]     # (n_x, n_y, n_z) point counts

    @property
    def n_points(self) -> int:
        """Total number of grid points, i.e. the product of ``shape``."""
        return int(np.asarray(self.shape).prod())

    def points(self) -> np.ndarray:
        """Return every grid point as an ``(n_points, 3)`` array.

        Rows are ordered with x varying slowest and z varying fastest,
        which is the scan order used for the cube data section.
        """
        nx, ny, nz = self.shape
        axes = [
            np.arange(count) * self.spacing[axis] + self.origin[axis]
            for axis, count in enumerate((nx, ny, nz))
        ]
        # indexing="ij" keeps the axes in (x, y, z) order, so a C-order
        # ravel of each component walks z fastest and x slowest.
        mesh = np.meshgrid(*axes, indexing="ij")
        return np.stack([component.ravel() for component in mesh], axis=1)


def make_uniform_grid(
    mol: Molecule,
    *,
    spacing: float = 0.2,
    padding: float = 4.0,
) -> CubeGrid:
    """Build an axis-aligned grid that encloses the molecule.

    The box spans, on every axis, from the smallest atomic coordinate
    minus ``padding`` to the largest atomic coordinate plus ``padding``
    and is sampled with the same step ``spacing`` along x, y and z
    (all lengths in bohr).

    For a molecule that extends ``L`` bohr along an axis, that axis
    receives ``ceil((L + 2 * padding) / spacing) + 1`` points, so the
    last point lies on or just beyond the upper face of the box.

    Parameters
    ----------
    mol:
        Object whose ``atoms`` each expose an ``xyz`` triple.
    spacing:
        Step between neighbouring grid points; must be finite and > 0.
    padding:
        Margin added on both sides of the atomic bounding box; must be
        finite. A negative value shrinks the box and is accepted as
        long as the box does not become inverted.

    Raises
    ------
    ValueError
        If the molecule has no atoms, if ``spacing`` is not a positive
        finite number, if ``padding`` is not finite, or if a negative
        ``padding`` makes the box extent negative along some axis.
    """
    # Reject these up front: a zero/negative/NaN spacing would otherwise
    # turn into a division by zero or a negative point count below.
    if not (np.isfinite(spacing) and spacing > 0):
        raise ValueError(
            "make_uniform_grid: spacing must be a positive finite number, "
            f"got {spacing!r}"
        )
    if not np.isfinite(padding):
        raise ValueError(
            f"make_uniform_grid: padding must be finite, got {padding!r}"
        )

    coords = np.asarray([list(a.xyz) for a in mol.atoms], dtype=float)
    if coords.size == 0:
        raise ValueError("make_uniform_grid: molecule has no atoms")

    lo = coords.min(axis=0) - padding
    hi = coords.max(axis=0) + padding
    extent = hi - lo
    if np.any(extent < 0):
        raise ValueError(
            "make_uniform_grid: negative padding "
            f"({padding!r}) leaves an empty box around the molecule"
        )

    # ceil(...) intervals need one more point than intervals to close
    # the box on its upper side.
    n = np.ceil(extent / spacing).astype(int) + 1
    return CubeGrid(
        origin=lo,
        spacing=np.full(3, float(spacing)),
        shape=(int(n[0]), int(n[1]), int(n[2])),
    )


# ---------------------------------------------------------------------------
# Volumetric scalar evaluation
# ---------------------------------------------------------------------------

def _density_on_grid(
    D: np.ndarray, basis: BasisSet, grid: CubeGrid,
    *, chunk_size: int = 200_000,
) -> np.ndarray:
    """Evaluate ρ(r) = Σ_{μν} D_{μν} χ_μ(r) χ_ν(r) on every grid point.

    The grid points are processed ``chunk_size`` at a time so that only
    one chunk's worth of AO values is held in memory at once. The
    result is returned reshaped to ``grid.shape``.

    Parameters
    ----------
    D:
        Density matrix in the AO basis; it does not need to be
        symmetric for the contraction below to be correct.
    basis:
        Basis set handed to ``evaluate_ao``.
    grid:
        Grid providing ``points()`` and ``shape``.
    chunk_size:
        Number of grid points evaluated per batch; must be positive.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive. (A negative step would make
        the loop a no-op and hand back uninitialised memory.)
    """
    if chunk_size <= 0:
        raise ValueError(
            f"_density_on_grid: chunk_size must be positive, got {chunk_size!r}"
        )
    pts = grid.points()
    n_pts = pts.shape[0]
    out = np.empty(n_pts, dtype=float)
    for start in range(0, n_pts, chunk_size):
        stop = start + chunk_size
        # AO values for this batch; expected layout is (m, nb) with one
        # row per grid point — as the contraction below requires.
        chi = evaluate_ao(basis, pts[start:stop])
        # ρ_m = Σ_j (χ D)_{mj} χ_{mj}. Doing the χ·D product as a plain
        # matrix multiply lets it run through BLAS instead of the
        # unoptimised three-operand einsum loop; the price is one
        # temporary of the same size as chi.
        out[start:stop] = np.einsum("mj,mj->m", np.matmul(chi, D), chi)
    return out.reshape(grid.shape)


def _mo_on_grid(
    C_col: np.ndarray, basis: BasisSet, grid: CubeGrid,
    *, chunk_size: int = 200_000,
) -> np.ndarray:
    """Evaluate one orbital φ(r) = Σ_μ C_μ χ_μ(r) on every grid point.

    ``C_col`` is the coefficient vector of a single MO. The grid points
    are processed ``chunk_size`` at a time and the values are returned
    reshaped to ``grid.shape``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive. (A negative step would make
        the loop a no-op and hand back uninitialised memory.)
    """
    if chunk_size <= 0:
        raise ValueError(
            f"_mo_on_grid: chunk_size must be positive, got {chunk_size!r}"
        )
    pts = grid.points()
    n_pts = pts.shape[0]
    out = np.empty(n_pts, dtype=float)
    for start in range(0, n_pts, chunk_size):
        stop = start + chunk_size
        # One row of AO values per grid point, contracted with the MO
        # coefficients to give the orbital amplitude at that point.
        chi = evaluate_ao(basis, pts[start:stop])
        out[start:stop] = chi @ C_col
    return out.reshape(grid.shape)


# ---------------------------------------------------------------------------
# Cube file writer
# ---------------------------------------------------------------------------

def _format_atoms_block(mol: Molecule) -> str:
    lines = []
    for a in mol.atoms:
        x, y, z = a.xyz
        # The "atomic charge" column is conventionally the nuclear
        # charge as a float — viewers read Z and ignore the value, but
        # the canonical convention is to put Z there too.
        lines.append(
            f"{a.Z:5d} {float(a.Z):12.6f} {x:12.6f} {y:12.6f} {z:12.6f}"
        )
    return "\n".join(lines)


def _write_cube_header(
    out, *, title: str, comment: str, mol: Molecule,
    grid: CubeGrid, n_values: Optional[int] = None,
    extra_header_int: Optional[int] = None,
    extra_header_ints: Optional[Sequence[int]] = None,
) -> None:
    """Write everything that precedes the volumetric data in a cube file.

    Emitted in order: the two comment lines, the atom-count / origin
    line, the three voxel lines (axis-aligned, taken from
    ``grid.spacing``), the atom records and, when ``extra_header_ints``
    is given, the dataset-id line ``<count>  <id_0>  <id_1> ...``.

    Parameters
    ----------
    out:
        Writable text stream (only ``write`` is used).
    title, comment:
        First and second header line. Embedded line breaks are replaced
        by spaces because each must occupy exactly one line.
    mol:
        Molecule whose ``atoms`` are listed in the header.
    grid:
        Grid supplying ``origin``, ``spacing`` and ``shape``.
    n_values:
        Number of values stored per voxel. Written as the trailing
        field of the origin line in the multi-value layout.
    extra_header_int:
        Accepted for interface compatibility; not used by this function.
    extra_header_ints:
        Dataset (MO) labels. When given, the id line is written and the
        atom count is negated.

    The atom count is written negative whenever the id line is present
    *or* ``n_values > 1``. Tying the sign to the id line keeps the file
    self-consistent for a one-element stack: previously such a header
    announced a plain scalar cube (positive count) and then still
    emitted the id line, which a reader would take for data.
    """
    # The header layout is line-based, so a stray newline in either
    # comment line would shift every following record.
    out.write(" ".join(str(title).splitlines()) + "\n")
    out.write(" ".join(str(comment).splitlines()) + "\n")

    n_atoms = len(mol.atoms)
    has_ids = extra_header_ints is not None
    multi_value = has_ids or (n_values is not None and n_values > 1)
    origin_fields = (
        f"{grid.origin[0]:12.6f} {grid.origin[1]:12.6f} "
        f"{grid.origin[2]:12.6f}"
    )
    if multi_value:
        # Fall back to the number of labels if the caller gave ids but
        # no explicit value count.
        n_val = n_values if n_values is not None else len(extra_header_ints)
        # NOTE(review): some MO cubes in the wild carry 1 in this field
        # and give the MO count only on the id line — confirm which
        # variant the target viewers expect. Existing behaviour is kept.
        out.write(f"{-n_atoms:5d} {origin_fields} {n_val:5d}\n")
    else:
        out.write(f"{n_atoms:5d} {origin_fields}\n")

    nx, ny, nz = grid.shape
    out.write(f"{nx:5d} {grid.spacing[0]:12.6f}     0.000000     0.000000\n")
    out.write(f"{ny:5d}     0.000000 {grid.spacing[1]:12.6f}     0.000000\n")
    out.write(f"{nz:5d}     0.000000     0.000000 {grid.spacing[2]:12.6f}\n")
    out.write(_format_atoms_block(mol))
    out.write("\n")

    # Dataset-id line: count first, then one label per stored volume.
    if has_ids:
        toks = [str(len(extra_header_ints))] + [str(i) for i in extra_header_ints]
        out.write(" " + "  ".join(toks) + "\n")


def _write_cube_data(out, data: np.ndarray) -> None:
    """Write a (..., n_z) array (or stacked) flat with 6 numbers per
    line. Cube wants z as the inner running index."""
    flat = data.reshape(-1)
    for i in range(0, flat.size, 6):
        chunk = flat[i:i + 6]
        out.write(" ".join(f"{x:13.5e}" for x in chunk) + "\n")


def write_cube_density(
    path: Union[str, Path],
    D: np.ndarray,
    basis: BasisSet,
    mol: Molecule,
    *,
    grid: Optional[CubeGrid] = None,
    spacing: float = 0.2,
    padding: float = 4.0,
    title: str = "vibe-qc electron density",
    comment: str = "rho(r) in e/bohr^3",
) -> Path:
    """Evaluate the electron density on a grid and save it as a cube file.

    ``D`` is converted to a float array and contracted with the basis
    functions at every grid point. If ``grid`` is ``None`` a grid is
    built with :func:`make_uniform_grid` from ``spacing`` and
    ``padding``; when a grid is supplied those two arguments are not
    used. ``title`` and ``comment`` become the two header lines.

    For an unrestricted (UHF / UKS) calculation pass the sum of the
    alpha and beta density matrices to obtain the total density.

    Returns the output location as a :class:`~pathlib.Path`.
    """
    target_grid = (
        make_uniform_grid(mol, spacing=spacing, padding=padding)
        if grid is None
        else grid
    )
    density_matrix = np.asarray(D, dtype=float)
    # The values are computed before the file is opened, so a failure
    # during evaluation does not leave a partial file behind.
    values = _density_on_grid(density_matrix, basis, target_grid)

    destination = Path(path)
    with destination.open("w") as handle:
        _write_cube_header(
            handle, title=title, comment=comment, mol=mol, grid=target_grid,
        )
        _write_cube_data(handle, values)
    return destination


def write_cube_mo(
    path: Union[str, Path],
    C: np.ndarray,
    index: int,
    basis: BasisSet,
    mol: Molecule,
    *,
    grid: Optional[CubeGrid] = None,
    spacing: float = 0.2,
    padding: float = 4.0,
    title: Optional[str] = None,
) -> Path:
    """Save one molecular orbital, sampled on a grid, as a cube file.

    ``C`` holds the MO coefficients with one column per orbital (rows
    are AOs); ``index`` selects the column and counts from zero.
    Negative indices are rejected. If ``grid`` is ``None`` a grid is
    built with :func:`make_uniform_grid` from ``spacing`` and
    ``padding``. A missing or empty ``title`` is replaced by a default
    that names the orbital index.

    Raises :class:`IndexError` when ``index`` is outside
    ``[0, C.shape[1])``. Returns the output location as a
    :class:`~pathlib.Path`.
    """
    coeffs = np.asarray(C, dtype=float)
    index_ok = 0 <= index < coeffs.shape[1]
    if not index_ok:
        raise IndexError(
            f"MO index {index} out of range for C with shape {coeffs.shape}"
        )

    target_grid = grid
    if target_grid is None:
        target_grid = make_uniform_grid(mol, spacing=spacing, padding=padding)

    orbital_values = _mo_on_grid(coeffs[:, index], basis, target_grid)

    destination = Path(path)
    with destination.open("w") as handle:
        header_title = title or f"vibeqc MO {index}"
        _write_cube_header(
            handle,
            title=header_title,
            comment="phi(r), units: 1/bohr^(3/2)",
            mol=mol,
            grid=target_grid,
        )
        _write_cube_data(handle, orbital_values)
    return destination


def write_cube_mos(
    path: Union[str, Path],
    C: np.ndarray,
    indices: Iterable[int],
    basis: BasisSet,
    mol: Molecule,
    *,
    grid: Optional[CubeGrid] = None,
    spacing: float = 0.2,
    padding: float = 4.0,
    title: str = "vibe-qc MOs",
) -> Path:
    """Save several molecular orbitals into one multi-value cube file.

    ``C`` holds the MO coefficients with one column per orbital and
    ``indices`` lists the zero-based columns to export, in the order
    they are stored per voxel. The header is given the number of
    orbitals and their one-based labels. If ``grid`` is ``None`` a grid
    is built with :func:`make_uniform_grid` from ``spacing`` and
    ``padding``.

    Raises :class:`ValueError` when ``indices`` is empty and
    :class:`IndexError` when any index is outside ``[0, C.shape[1])``.
    Returns the output location as a :class:`~pathlib.Path`.
    """
    coeffs = np.asarray(C, dtype=float)
    selected = list(indices)
    if len(selected) == 0:
        raise ValueError("write_cube_mos: indices is empty")
    for mo in selected:
        if 0 <= mo < coeffs.shape[1]:
            continue
        raise IndexError(
            f"MO index {mo} out of range for C with shape {coeffs.shape}"
        )

    target_grid = grid
    if target_grid is None:
        target_grid = make_uniform_grid(mol, spacing=spacing, padding=padding)

    # The AO values of a batch of points are computed once and then
    # contracted with all requested MO columns in a single product.
    grid_points = target_grid.points()
    n_points = grid_points.shape[0]
    n_orbitals = len(selected)
    batch = 200_000
    columns = coeffs[:, selected]  # one column per requested MO
    stacked = np.empty((n_points, n_orbitals), dtype=float)
    for lo in range(0, n_points, batch):
        hi = lo + batch
        ao_values = evaluate_ao(basis, grid_points[lo:hi])
        stacked[lo:hi, :] = ao_values @ columns
    # Lay the values out as (nx, ny, nz, n_orbitals) so that, once
    # flattened, the orbital index runs fastest within each voxel.
    stacked = stacked.reshape((*target_grid.shape, n_orbitals))

    destination = Path(path)
    index_list = ",".join(str(mo) for mo in selected)
    one_based_labels = [mo + 1 for mo in selected]
    with destination.open("w") as handle:
        _write_cube_header(
            handle,
            title=title,
            comment="MO stack: phi_i(r) for indices " + index_list,
            mol=mol,
            grid=target_grid,
            n_values=n_orbitals,
            extra_header_ints=one_based_labels,
        )
        _write_cube_data(handle, stacked)
    return destination
