"""NH₃ umbrella inversion via climbing-image Nudged Elastic Band.

Run:
    .venv/bin/python input-nh3-umbrella-neb.py

Produces:
    output-nh3-umbrella-neb.traj   — full reaction path; `ase gui` animates it
    output-nh3-umbrella-neb.out    — text summary + MEP energies + TS geometry

The ammonia molecule inverts by flipping its three hydrogens through a
planar (D₃h) transition state. This script:

  1. Builds pyramidal NH₃ in both umbrella orientations (H's below vs. above N).
  2. Relaxes both endpoints with BFGSLineSearch.
  3. Sets up a 5-intermediate NEB band between them.
  4. Runs climbing-image NEB with FIRE to locate the TS.
  5. Writes the MEP energies, the TS geometry, and the trajectory.

The experimental inversion barrier is ~5.8 kcal/mol. HF/STO-3G
overestimates it (~11 kcal/mol), the usual overshoot of Hartree–Fock in
a minimal basis, so treat the result as qualitative only. For
quantitative agreement, use MP2 or a hybrid DFT functional with a
def2-TZVP basis.
"""

from pathlib import Path

import numpy as np
from ase import Atoms
from ase.io import write as ase_write
from ase.mep import NEB
from ase.optimize import BFGSLineSearch, FIRE

from vibeqc.ase import VibeQC

# All outputs are written into this script's parent directory.
HERE = Path(__file__).parent
TRAJ_OUT = HERE.joinpath("output-nh3-umbrella-neb.traj")  # every NEB image; view with `ase gui`
TEXT_OUT = HERE.joinpath("output-nh3-umbrella-neb.out")   # MEP energies, barrier, TS geometry


# --- geometry builders ----------------------------------------------------

def build_nh3(flip: bool = False) -> Atoms:
    """Return pyramidal NH₃ with a VibeQC STO-3G calculator attached.

    Nitrogen sits at the origin. The three hydrogens are spaced 120° apart
    in azimuth around the z axis, using the experimental N–H bond length
    (1.012 Å) and H–N–H angle (106.67°).

    Args:
        flip: ``False`` puts the hydrogens at negative z (below N);
            ``True`` puts them at positive z, i.e. the inverted umbrella.

    Returns:
        An ``Atoms("NH3")`` object whose ``calc`` is ``VibeQC(basis="sto-3g")``.
    """
    hnh_angle = np.deg2rad(106.67)      # H–N–H angle, experimental
    bond_length = 1.012                 # N–H distance, experimental (Å)

    # Split each bond into an in-plane (xy) fraction and an axial (z) fraction.
    # Three hydrogens 120° apart on a circle of radius r·sin(α) are
    # sqrt(3)·r·sin(α) apart. Equating that squared distance to the
    # law-of-cosines value 2·r²·(1 − cos θ) fixes sin(α).
    in_plane = np.sqrt(2.0 * (1.0 - np.cos(hnh_angle)) / 3.0)
    axial = np.sqrt(1.0 - in_plane ** 2)
    direction = -1
    if flip:
        direction = +1

    hydrogens = [
        [
            bond_length * in_plane * np.cos(azimuth),
            bond_length * in_plane * np.sin(azimuth),
            direction * bond_length * axial,
        ]
        for azimuth in (k * 2.0 * np.pi / 3.0 for k in range(3))
    ]

    molecule = Atoms("NH3", positions=[[0.0, 0.0, 0.0], *hydrogens])
    molecule.calc = VibeQC(basis="sto-3g")
    return molecule


# --- run ------------------------------------------------------------------

# 1. Relaxed endpoints. The two umbrella orientations are mirror images, so
# after relaxation they should have the same energy.
initial = build_nh3(flip=False)
initial_converged = BFGSLineSearch(initial, logfile=None).run(fmax=0.05, steps=50)
final = build_nh3(flip=True)
final_converged = BFGSLineSearch(final, logfile=None).run(fmax=0.05, steps=50)

# Every ΔE in the report is measured from ``e_initial``. An endpoint that hit
# the step limit would silently skew the whole MEP, so flag it rather than
# carrying on quietly.
for label, converged in (("initial", initial_converged), ("final", final_converged)):
    if not converged:
        print(f"WARNING: {label} endpoint relaxation did not converge "
              f"(fmax=0.05, 50 steps); energies are measured from an "
              f"unrelaxed reference.")

e_initial = initial.get_potential_energy()
e_final = final.get_potential_energy()

# 2. Build a 5-image interior. Each intermediate gets its own calculator
# instance rather than sharing one.
N_INTERIOR = 5
images = [initial]
for _ in range(N_INTERIOR):
    img = initial.copy()
    img.calc = VibeQC(basis="sto-3g")
    images.append(img)
images.append(final)

# 3. Linear interpolation across the band, then climbing-image NEB.
neb = NEB(images, k=0.1, climb=True, method="improvedtangent")
neb.interpolate()

# FIRE (damped dynamics) steps on forces alone. NEB forces are the
# perpendicular true force plus the tangential spring force. They are not the
# gradient of any single energy, so a line-search optimiser such as
# BFGSLineSearch (used above for the endpoints) would apply an
# energy-decrease criterion that does not fit this force field.
neb_converged = FIRE(neb, logfile=None).run(fmax=0.05, steps=200)
if not neb_converged:
    # Still write the outputs (the partial band is useful to inspect), but do
    # not let an unconverged barrier pass as final.
    print("WARNING: CI-NEB did not converge (fmax=0.05, 200 steps); "
          "the reported barrier and TS geometry are provisional.")

# 4. Report.
ase_write(str(TRAJ_OUT), images)

kcal_per_eV = 23.0605
e_exp_kcal = 5.8  # experimental NH3 inversion barrier (kcal/mol)

# Evaluate each image's energy once. The MEP table, the barrier and the TS
# index all come from this single list, so they cannot disagree.
energies = [img.get_potential_energy() for img in images]
ts_index = int(np.argmax(energies))
barrier_ev = energies[ts_index] - e_initial
barrier_kcal = barrier_ev * kcal_per_eV

# Write UTF-8 explicitly. The report contains non-ASCII characters (Δ, ×, —,
# Å) that the platform default encoding (e.g. cp1252 on Windows) may not be
# able to encode.
with open(TEXT_OUT, "w", encoding="utf-8") as f:
    f.write("NH3 umbrella inversion via CI-NEB @ HF/STO-3G\n")
    f.write("=" * 50 + "\n\n")
    f.write("Minimum energy path:\n")
    f.write(f"  {'image':>6s}  {'ΔE (eV)':>10s}  {'ΔE (kcal/mol)':>14s}\n")
    for i, energy in enumerate(energies):
        de = energy - e_initial
        f.write(f"  {i:6d}  {de:10.4f}  {de * kcal_per_eV:14.3f}\n")

    f.write("\n")
    f.write(f"Transition state: image {ts_index}\n")
    f.write(f"Barrier: {barrier_ev:.4f} eV  ({barrier_kcal:.2f} kcal/mol)\n")
    f.write(f"Experimental barrier: ~{e_exp_kcal} kcal/mol\n")
    # Report the actual ratio instead of asserting a fixed "~2×".
    f.write(f"Computed / experimental: {barrier_kcal / e_exp_kcal:.2f}× "
            "(HF/STO-3G is expected to over-estimate, ~2×).\n")
    # The endpoints are mirror images, so this should be ~0. A large value
    # points at an endpoint relaxation problem.
    f.write(f"Endpoint check E(final) - E(initial): {e_final - e_initial:+.2e} eV\n\n")

    ts_atoms = images[ts_index]
    f.write(f"TS geometry (image {ts_index}, Å):\n")
    for sym, p in zip(ts_atoms.get_chemical_symbols(), ts_atoms.positions):
        f.write(f"  {sym}   {p[0]:+.5f}  {p[1]:+.5f}  {p[2]:+.5f}\n")

print(f"Wrote {TRAJ_OUT.name}  and  {TEXT_OUT.name}")
