"""Phase G1-GDF — analytic Γ-only periodic GDF (compcell) atomic gradient.

The GDF two-electron energy at Γ with density D:

  E_2e = ½ tr(D · J) − ¼ α_HF · tr(D · K)

where J and K are built from the compcell cderi L via
:func:`vibeqc.pbc_gdf._build_j_from_lpq` / :func:`_build_k_from_lpq`.

**J gradient (Coulomb).** The DF-J gradient formula:

  ∂E_J/∂R = Σ_P γ̃_P Σ_μν D_μν ∂T[P,μν]/∂R − ½ Σ_PQ γ̃_P γ̃_Q ∂M_PQ/∂R

where γ̃ = M^{-1} ρ, ρ_P = Σ_μν T[P,μν] D_μν, and (M, T) are the
compensated 2c metric and 3c tensor.

The compensation matrix A maps the fused basis (modrho-aux + compensating
charges) to the physical aux space: M = A·M_fused·A^T, T = A·T_fused.
The gradient maps back onto the fused basis via γ̃_fused = A^T·γ̃.

**K gradient (exchange).** The DF-K gradient:

  ∂E_K/∂R = +α_HF Σ_PQ ω_PQ ∂M_PQ/∂R − 2 α_HF Σ_{P,μν} Y^P_{μν} ∂(P|μν)/∂R

with ω = (η^P : η^Q), η = M^{-1} K_occ, K^P_occ,ij = C_occ^T T^P C_occ,
Y^P = C_occ · η^P · C_occ^T. For pure DFT (α_HF = 0) this term is zero.

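**Current scope:** Γ-only RHF, compcell cderi only (no PW mesh). Both the
DF-J and the DF-K (α_HF > 0, including the Ewald exxdiv shift) gradients
are implemented. The V_ne Pulay term is obtained by central finite
differences of the SCF's FT-based Ewald V_ne (gauge-consistent but slow),
and the AFT correction to the 2c metric is not yet applied. The 2e terms
use the C++ lattice-summed gradient kernels: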
- :func:`vibeqc._vibeqc_core.compute_2c_eri_lattice_gradient_weighted`
- :func:`vibeqc._vibeqc_core.compute_3c_eri_lattice_gradient_weighted`
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

from ._vibeqc_core import (
    BasisSet,
    LatticeSumOptions,
    PeriodicSystem,
    compute_2c_eri_lattice,
    compute_2c_eri_lattice_gradient_weighted,
    compute_3c_eri_lattice,
    compute_3c_eri_lattice_gradient_weighted,
    compute_overlap_lattice,
    kinetic_lattice_gradient_contribution,
    nuclear_repulsion_gradient_per_cell,
    overlap_lattice_gradient_contribution,
)
from .aux_basis import (
    fuse_transform_matrix,
    make_compensating_basis,
    make_fused_basis,
    make_modrho_aux_basis,
)

__all__ = ["compute_gdf_gradient_rhf_gamma", "compute_gdf_gradient"]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _gamma_density_lattice_set(template_lat, D: np.ndarray) -> np.ndarray:
    """Return the single-cell Γ density as a float64 ndarray.

    At Γ the k=0 Bloch phase is 1 in every image cell, so the
    lattice-resolved density D(g) equals D_Γ for every g. The DF (GDF)
    gradient path consumes only the single-cell density (it never goes
    through the 4c lattice ERI gradient), so no lattice set is built here.
    ``template_lat`` is accepted for interface compatibility and ignored.
    """
    single_cell_density = np.asarray(D, dtype=np.float64)
    return single_cell_density


@dataclass
class _CompcellGradientCache:
    """Cached intermediates from the compcell pipeline for gradient use.

    Built by :func:`_build_compcell_gradient_cache`. The compensation
    matrix A and the fused-basis integrals M_fused and T_fused depend only
    on the geometry (and on the aux basis, ``compcell_eta`` and lattice-sum
    options used to build them), not on the density. One cache can
    therefore be reused for every gradient evaluation at the same geometry.
    The gradient routines form the compensated quantities on demand as
    M = A·M_fused·A^T and T = A·T_fused.

    NOTE(review): field order defines the positional constructor
    signature; do not reorder fields.
    """

    A: np.ndarray  # (n_aux, n_fused) compensation matrix: fused -> aux space
    M_fused: np.ndarray  # (n_fused, n_fused) pre-compensation 2c metric, symmetrized
    T_fused: np.ndarray  # (n_fused, n_orb, n_orb) pre-compensation 3c tensor
    fused_basis: BasisSet  # the fused basis (modrho-aux + compensating charges)
    n_fused: int  # fused_basis.nbasis
    n_aux: int  # number of modrho-rescaled aux functions
    n_orb: int  # number of AO (orbital) basis functions


def _build_compcell_gradient_cache(
    system: PeriodicSystem,
    ao_basis: BasisSet,
    aux_basis: BasisSet,
    compcell_eta: float,
    lat_opts_2c: LatticeSumOptions,
    lat_opts_3c: LatticeSumOptions,
) -> _CompcellGradientCache:
    """Assemble the density-independent compcell intermediates for a gradient.

    Mirrors the front half of :func:`vibeqc.aux_basis.build_lpq_compcell`:
    modrho-rescaled aux basis -> compensating charges -> fused basis ->
    compensation matrix A -> fused-basis 2c metric and 3c tensor.

    Notes
    -----
    No AFT correction is applied to the 2c metric here (Increment 1, kept
    simple for FD validation). If the SCF applied the AFT correction, the
    gradient is not fully consistent with the SCF energy.
    TODO: apply the AFT correction to M_fused when the SCF did
    (``apply_aft_correction=True``).
    """
    cell_mol = system.unit_cell_molecule()
    modrho = make_modrho_aux_basis(aux_basis, cell_mol)
    compensating = make_compensating_basis(modrho, cell_mol, eta=float(compcell_eta))
    fused_basis = make_fused_basis(modrho, compensating, cell_mol)
    comp_matrix = fuse_transform_matrix(modrho, compensating)

    # Fused-basis 2c metric, explicitly symmetrized against round-off.
    metric_raw = np.asarray(compute_2c_eri_lattice(fused_basis, system, lat_opts_2c))
    metric = 0.5 * (metric_raw + metric_raw.T)

    # Bare fused-basis 3c tensor. The compcell pipeline notes that the 3c AFT
    # term cancels after eigendecomposition-and-threshold, so none is applied.
    tensor = np.asarray(
        compute_3c_eri_lattice(ao_basis, fused_basis, system, lat_opts_3c)
    )

    return _CompcellGradientCache(
        A=comp_matrix,
        M_fused=metric,
        T_fused=tensor,
        fused_basis=fused_basis,
        n_fused=fused_basis.nbasis,
        n_aux=modrho.nbasis,
        n_orb=ao_basis.nbasis,
    )


def _compute_j_gradient_compcell(
    system: PeriodicSystem,
    ao_basis: BasisSet,
    D: np.ndarray,
    cache: _CompcellGradientCache,
    lat_opts_2c: LatticeSumOptions,
    lat_opts_3c: LatticeSumOptions,
) -> np.ndarray:
    """Analytic DF-J (Coulomb) gradient along the compcell route.

    Evaluates

        dE_J/dR = sum_P g_P sum_mn D_mn dT[P,mn]/dR - 1/2 sum_PQ g_P g_Q dM_PQ/dR

    with g = M^{-1} rho and rho_P = sum_mn T[P,mn] D_mn. The compensated
    (M, T) are formed from the fused-basis cache via A, and the weights are
    pulled back onto the fused basis with g_fused = A^T g so the fused-basis
    lattice gradient kernels can be used directly.

    Parameters
    ----------
    system, ao_basis, D
        Periodic system, orbital basis, and converged density (n_orb, n_orb).
    cache
        Pre-built geometry cache from :func:`_build_compcell_gradient_cache`.
    lat_opts_2c, lat_opts_3c
        Lattice-sum options for 2c and 3c integrals (must match SCF).

    Returns
    -------
    grad_J : np.ndarray of shape (n_atoms, 3) in Hartree/bohr.
    """
    n_aux, n_orb = cache.n_aux, cache.n_orb
    comp = cache.A

    # Compensated metric (n_aux, n_aux), symmetrized, and 3c tensor
    # (n_aux, n_orb, n_orb).
    metric = comp @ cache.M_fused @ comp.T
    metric = 0.5 * (metric + metric.T)
    tensor = np.einsum("iP,Pmn->imn", comp, cache.T_fused, optimize=True)

    # Fitted density: rho_P = sum_mn T[P,mn] D[mn]; g = M^{-1} rho.
    dens_vec = np.asarray(D, dtype=np.float64).ravel()
    rho = tensor.reshape(n_aux, n_orb * n_orb) @ dens_vec
    fit_coeffs = np.linalg.solve(metric, rho)

    # Pull the fit coefficients back onto the fused basis: g_fused = A^T g.
    fit_fused = comp.T @ fit_coeffs

    # 2c weight -1/2 g g^T; 3c weight g ⊗ D laid out row-major for the C++ kernel.
    weight_2c = -0.5 * np.outer(fit_fused, fit_fused)
    weight_3c = np.ascontiguousarray(np.outer(fit_fused, dens_vec))

    grad_metric = np.asarray(
        compute_2c_eri_lattice_gradient_weighted(
            cache.fused_basis, system, lat_opts_2c, weight_2c
        ),
        dtype=np.float64,
    )
    grad_tensor = np.asarray(
        compute_3c_eri_lattice_gradient_weighted(
            ao_basis, cache.fused_basis, system, lat_opts_3c, weight_3c
        ),
        dtype=np.float64,
    )
    return grad_metric + grad_tensor


def _compute_k_gradient_compcell(
    system: PeriodicSystem,
    ao_basis: BasisSet,
    D: np.ndarray,
    C_occ: np.ndarray,
    alpha_hf: float,
    cache: _CompcellGradientCache,
    lat_opts_2c: LatticeSumOptions,
    lat_opts_3c: LatticeSumOptions,
) -> np.ndarray:
    """Analytic DF-K (exchange) gradient along the compcell route.

    Evaluates

        dE_K/dR = +a sum_PQ w_PQ dM_PQ/dR - 2 a sum_{P,mn} Y^P_mn dT[P,mn]/dR

    where a = ``alpha_hf``, B^P = C_occ^T T^P C_occ, eta = M^{-1} B,
    w_PQ = sum_ij eta^P_ij eta^Q_ij and Y^P = C_occ eta^P C_occ^T. Both
    weights are mapped onto the fused basis through A before calling the
    lattice gradient kernels. ``D`` is accepted for signature symmetry with
    the J routine but is not used. Returns zeros when ``alpha_hf == 0``.

    Parameters
    ----------
    system, ao_basis, D, C_occ, alpha_hf
        Same conventions as the molecular DF-K gradient.
    cache, lat_opts_*
        Same as :func:`_compute_j_gradient_compcell`.

    Returns
    -------
    grad_K : np.ndarray of shape (n_atoms, 3) in Hartree/bohr.
    """
    if alpha_hf == 0.0:
        return np.zeros((len(system.unit_cell), 3), dtype=np.float64)

    n_aux, n_fused, n_orb = cache.n_aux, cache.n_fused, cache.n_orb
    n_occ = C_occ.shape[1]
    comp = cache.A

    metric = comp @ cache.M_fused @ comp.T
    metric = 0.5 * (metric + metric.T)
    tensor = np.einsum("iP,Pmn->imn", comp, cache.T_fused, optimize=True)

    # Occupied-occupied 3c block B^P_ij = (C_occ^T T^P C_occ)_ij, (n_aux, n_occ, n_occ).
    b_occ = np.einsum("Pmn,mi,nj->Pij", tensor, C_occ, C_occ, optimize=True)

    # eta = M^{-1} B, solved for every (i, j) column at once.
    eta_cols = np.linalg.solve(metric, b_occ.reshape(n_aux, n_occ * n_occ))
    eta = eta_cols.reshape(n_aux, n_occ, n_occ)

    # 2c weight: w_PQ = sum_ij eta^P_ij eta^Q_ij, then A^T w A on the fused basis.
    omega = np.einsum("Pij,Qij->PQ", eta, eta, optimize=True)
    omega_f = comp.T @ omega @ comp

    # 3c weight: Y^P = C_occ eta^P C_occ^T, then A^T Y on the fused basis.
    y_aux = np.einsum("Pij,mi,nj->Pmn", eta, C_occ, C_occ, optimize=True)
    y_f = np.einsum("Pi,Pmn->imn", comp, y_aux, optimize=True)
    weight_3c = np.ascontiguousarray(y_f.reshape(n_fused, n_orb * n_orb))

    grad_metric = np.asarray(
        compute_2c_eri_lattice_gradient_weighted(
            cache.fused_basis, system, lat_opts_2c, alpha_hf * omega_f
        ),
        dtype=np.float64,
    )
    grad_tensor = np.asarray(
        compute_3c_eri_lattice_gradient_weighted(
            ao_basis, cache.fused_basis, system, lat_opts_3c, weight_3c
        ),
        dtype=np.float64,
    )
    return grad_metric - 2.0 * alpha_hf * grad_tensor


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def _v_ne_gradient_numerical(
    system,
    basis,
    D,
    gauge_lat_opts,
    *,
    delta: float = 3e-4,
    ke_cutoff: float = 200.0,
):
    """Central finite-difference gradient of tr(D · V_ne) w.r.t. atom positions.

    V_ne is rebuilt at each displaced geometry with
    ``compute_v_ne_ewald_3d_ft_gamma`` (the same builder the SCF uses), so
    the result is gauge-consistent with the SCF energy. The kinetic Pulay
    part is handled analytically by the caller. The cost is 6 · n_atoms
    V_ne builds (one basis rebuild each), so this is slow.

    Parameters
    ----------
    system
        Periodic system at the reference geometry.
    basis
        AO basis; rebuilt by name at every displaced geometry.
    D
        Converged density (n_orb, n_orb), held fixed during the FD.
    gauge_lat_opts
        Lattice-sum options passed to the V_ne builder.
    delta
        Cartesian FD step in bohr (default 3e-4, the previous hard-coded value).
    ke_cutoff
        Kinetic-energy cutoff passed to the V_ne builder (default 200.0).
        NOTE(review): presumably must match the SCF's value — confirm.

    Returns
    -------
    grad : np.ndarray of shape (n_atoms, 3) in Hartree/bohr.

    Raises
    ------
    ValueError
        If ``delta`` is not strictly positive.
    """
    # Local import kept as in the original (possibly to avoid an import cycle).
    from .periodic_v_ne import compute_v_ne_ewald_3d_ft_gamma

    if not delta > 0.0:
        raise ValueError(f"_v_ne_gradient_numerical: delta must be > 0 (got {delta})")

    D = np.asarray(D, dtype=np.float64)
    atoms = list(system.unit_cell)
    n_atoms = len(atoms)
    lattice = np.asarray(system.lattice, dtype=np.float64)

    def _energy_at(atom_idx: int, axis: int, step: float) -> float:
        """tr(D · V_ne) with atom ``atom_idx`` shifted by ``step`` along ``axis``."""
        displaced = []
        for i, atom in enumerate(atoms):
            xyz = list(atom.xyz)
            if i == atom_idx:
                xyz[axis] += step
            displaced.append(type(atom)(int(atom.Z), xyz))
        sys_d = PeriodicSystem(
            int(system.dim),
            lattice,
            displaced,
            charge=system.charge,
            multiplicity=system.multiplicity,
        )
        basis_d = type(basis)(sys_d.unit_cell_molecule(), basis.name)
        V_d = compute_v_ne_ewald_3d_ft_gamma(
            basis_d,
            sys_d,
            gauge_lat_opts,
            ke_cutoff=ke_cutoff,
        )
        return float(np.einsum("ij,ij->", D, np.asarray(V_d, dtype=np.float64)))

    grad = np.zeros((n_atoms, 3), dtype=np.float64)
    for a in range(n_atoms):
        for d in range(3):
            e_plus = _energy_at(a, d, +delta)
            e_minus = _energy_at(a, d, -delta)
            grad[a, d] = (e_plus - e_minus) / (2.0 * delta)
    return grad


def _uniform_lattice_set(basis, system, lattice_opts, block):
    """Return an overlap-shaped lattice matrix set with ``block`` in every cell.

    At Γ the k=0 Bloch phase is 1 in every image cell, so lattice-resolved
    Γ quantities (density, energy-weighted density) are identical in all
    images. The overlap lattice serves only as a template for the cell
    list; every block is overwritten with ``block``.
    """
    lat_set = compute_overlap_lattice(basis, system, lattice_opts)
    for c_idx in range(len(lat_set.cells)):
        lat_set.set_block(c_idx, block)
    return lat_set


def compute_gdf_gradient_rhf_gamma(
    system: PeriodicSystem,
    basis: BasisSet,
    result,  # PBCGDFResult
    *,
    aux_basis: BasisSet,
    compcell_eta: float = 1.0,
    alpha_hf: float = 1.0,
    lattice_opts: Optional[LatticeSumOptions] = None,
    cache: Optional[_CompcellGradientCache] = None,
    madelung: float = 0.0,
    gauge_lat_opts: Optional[LatticeSumOptions] = None,
) -> Tuple[np.ndarray, _CompcellGradientCache]:
    """Analytic Γ-only GDF (compcell) RHF atomic gradient.

    Sums, in order: the nuclear-repulsion gradient, the overlap
    (energy-weighted density) Lagrangian, the kinetic Pulay term, a
    finite-difference V_ne Pulay term and the DF-J gradient. For
    ``alpha_hf > 0`` it also adds the DF-K gradient and, when ``madelung``
    is non-zero, the Ewald exxdiv-shift term. This routine adds no
    XC-functional gradient.

    Parameters
    ----------
    system, basis
        Periodic system and AO basis.
    result
        Converged :class:`PBCGDFResult` from :func:`vibeqc.run_pbc_gdf_rhf`.
        Reads ``density``, ``mo_coeffs``, ``mo_energies`` and, when the
        exxdiv term is active, ``overlap``.
    aux_basis
        Auxiliary basis (unscaled — the gradient pipeline handles modrho
        rescaling internally, matching the SCF).
    compcell_eta
        Smooth-Gaussian exponent for the compensating charges. Must match
        the value used in the SCF.
    alpha_hf
        HF-exchange fraction (1.0 for pure HF, 0.0 for pure DFT). Scales
        both the DF-K gradient and the exxdiv-shift term.
    lattice_opts
        Lattice-sum options. Must match the SCF. If None, uses
        :class:`LatticeSumOptions()`.
    cache
        Optional pre-built gradient cache. If None, built fresh.
    madelung
        Exchange-divergence (Madelung) constant ξ of the SCF's exxdiv shift
        K -> K + ξ·S·D·S. 0.0 disables the term.
    gauge_lat_opts
        Lattice-sum options for the nuclear-repulsion gradient and the
        finite-difference V_ne term. Defaults to ``lattice_opts``.

    Returns
    -------
    grad : np.ndarray of shape (n_atoms, 3) in Hartree/bohr.
    cache : _CompcellGradientCache
        The cache built (or passed in). Caller can reuse across
        multiple gradient evaluations at the same geometry.

    Raises
    ------
    ValueError
        If the system has an odd number of electrons (closed-shell only).
    """
    if lattice_opts is None:
        lattice_opts = LatticeSumOptions()
    if gauge_lat_opts is None:
        gauge_lat_opts = lattice_opts

    n_elec = system.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            "compute_gdf_gradient_rhf_gamma: closed-shell only "
            f"(got {n_elec} electrons)"
        )
    n_occ = n_elec // 2

    D = np.asarray(result.density, dtype=np.float64)
    C = np.asarray(result.mo_coeffs, dtype=np.float64)
    C_occ = C[:, :n_occ]
    eps = np.asarray(result.mo_energies, dtype=np.float64)
    n_atoms = len(system.unit_cell)

    # Build or reuse the geometry-dependent (density-independent) cache.
    if cache is None:
        # For now, use the same lat_opts for 2c and 3c (no auto-rcut).
        cache = _build_compcell_gradient_cache(
            system,
            basis,
            aux_basis,
            compcell_eta=float(compcell_eta),
            lat_opts_2c=lattice_opts,
            lat_opts_3c=lattice_opts,
        )

    # ---- 1-electron Pulay + overlap + nuclear terms ---------------------
    # At Γ the lattice-resolved density and energy-weighted density are the
    # same in every image cell. The ε used in W come straight from the GDF
    # SCF (no EWALD_3D G=0 gauge shift on this route).
    D_set = _uniform_lattice_set(basis, system, lattice_opts, D)
    W_gamma = 2.0 * (C_occ * eps[:n_occ][None, :]) @ C_occ.T
    W_set = _uniform_lattice_set(basis, system, lattice_opts, W_gamma)

    grad = np.zeros((n_atoms, 3), dtype=np.float64)

    # Nuclear repulsion.
    grad += np.asarray(nuclear_repulsion_gradient_per_cell(system, gauge_lat_opts))

    # Overlap Lagrangian: -tr(W ∂S/∂R).
    grad += np.asarray(
        overlap_lattice_gradient_contribution(basis, system, W_set, lattice_opts)
    )

    # Kinetic Pulay: tr(D ∂T/∂R).
    grad += np.asarray(
        kinetic_lattice_gradient_contribution(basis, system, D_set, lattice_opts)
    )

    # V_ne Pulay: central FD of the SCF's FT-based Ewald V_ne
    # (gauge-consistent with the SCF).
    grad += _v_ne_gradient_numerical(system, basis, D, gauge_lat_opts)

    # ---- 2-electron DF gradient -----------------------------------------
    grad_J = _compute_j_gradient_compcell(
        system,
        basis,
        D,
        cache,
        lattice_opts,
        lattice_opts,
    )

    if alpha_hf > 0.0:
        # grad_K is the full derivative of -¼ α_HF tr(D·K_raw).
        grad_K = _compute_k_gradient_compcell(
            system,
            basis,
            D,
            C_occ,
            float(alpha_hf),
            cache,
            lattice_opts,
            lattice_opts,
        )
        grad += grad_J + grad_K

        # Exxdiv shift: the SCF uses K_shifted = K_raw + ξ·S·D·S, so the
        # energy gains -¼ α_HF ξ tr(D·S·D·S). At fixed D its derivative is
        # -½ α_HF ξ tr(D·S·D·∂S/∂R). overlap_lattice_gradient_contribution
        # returns -tr(W·∂S/∂R), hence W_exx = +½ α_HF ξ D·S·D.
        if abs(madelung) > 0.0:
            S = np.asarray(result.overlap, dtype=np.float64)
            W_exx_gamma = 0.5 * float(alpha_hf) * madelung * (D @ S @ D)
            W_exx_set = _uniform_lattice_set(basis, system, lattice_opts, W_exx_gamma)
            grad += np.asarray(
                overlap_lattice_gradient_contribution(
                    basis, system, W_exx_set, lattice_opts
                )
            )
    else:
        grad += grad_J

    return grad, cache


def compute_gdf_gradient(
    system: PeriodicSystem,
    basis: BasisSet,
    result,  # PBCGDFResult
    *,
    aux_basis: Optional[BasisSet] = None,
    aux_basis_name: str = "def2-svp-jk",
    compcell_eta: float = 1.0,
    lattice_opts: Optional[LatticeSumOptions] = None,
) -> np.ndarray:
    """Analytic GDF gradient from a converged PBCGDFResult.

    Convenience wrapper around :func:`compute_gdf_gradient_rhf_gamma`. It:

    - derives the EWALD_3D gauge lattice options;
    - looks up the cell's Madelung constant;
    - builds the aux basis if one is not supplied;
    - builds the compcell gradient cache.

    The wrapper always treats the result as pure HF exchange
    (``alpha_hf=1.0``).

    Parameters
    ----------
    system, basis
        Periodic system and AO basis.
    result
        Converged :class:`PBCGDFResult` from :func:`vibeqc.run_pbc_gdf_rhf`.
    aux_basis
        Auxiliary basis. If None, built from ``aux_basis_name``.
    aux_basis_name
        Aux basis name for auto-construction.
    compcell_eta
        Must match the SCF value.
    lattice_opts
        Lattice-sum options. If None, uses LatticeSumOptions().

    Returns
    -------
    grad : (n_atoms, 3) ndarray in Hartree/bohr.

    Raises
    ------
    ValueError
        If ``result.converged`` is false.
    """
    if not result.converged:
        raise ValueError("compute_gdf_gradient: SCF result is not converged")

    opts = lattice_opts if lattice_opts is not None else LatticeSumOptions()

    from .pbc_gdf import _gauge_lat_opts_ewald_3d
    from .madelung import madelung_constant_for_cell
    from .aux_basis import make_aux_basis_set

    gauge_opts = _gauge_lat_opts_ewald_3d(opts, system)
    xi = madelung_constant_for_cell(system)
    cell_mol = system.unit_cell_molecule()

    aux = (
        aux_basis
        if aux_basis is not None
        else make_aux_basis_set(cell_mol, aux_name=aux_basis_name)
    )
    eta = float(compcell_eta)

    # Geometry-dependent, density-independent (and expensive) cache.
    gradient_cache = _build_compcell_gradient_cache(
        system,
        basis,
        aux,
        compcell_eta=eta,
        lat_opts_2c=opts,
        lat_opts_3c=opts,
    )

    gradient, _reused_cache = compute_gdf_gradient_rhf_gamma(
        system,
        basis,
        result,
        aux_basis=aux,
        compcell_eta=eta,
        alpha_hf=1.0,
        lattice_opts=opts,
        cache=gradient_cache,
        madelung=float(xi),
        gauge_lat_opts=gauge_opts,
    )
    return gradient
