# Source code for vibeqc.periodic_symmetrize

"""Periodic-structure space-group detection and symmetrisation helpers."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator

import numpy as np

from ._vibeqc_core import (
    Atom,
    Crystal,
    PeriodicSystem,
    SpaceGroup,
    analyze_symmetry,
    attach_symmetry,
)


@dataclass(frozen=True)
class SymmetriseReport:
    """Immutable summary produced by one :func:`symmetrise` call.

    The field order below is also the positional constructor order, so it
    must not be rearranged.
    """

    # Symmetry tolerance handed to spglib / attach_symmetry.
    symprec: float
    # Space group detected on the input structure.
    spacegroup_before: SpaceGroup
    # Space group attached to the standardized output structure.
    spacegroup_after: SpaceGroup
    # Number of atoms in the input unit cell.
    n_atoms_before: int
    # Number of atoms in the output unit cell.
    n_atoms_after: int
    # |det(lattice)| of the input cell (bohr^3 per the field name).
    volume_before_bohr3: float
    # |det(lattice)| of the output cell (bohr^3 per the field name).
    volume_after_bohr3: float
    # RMS of matched atomic displacements between the non-idealized and the
    # returned standardized cell.
    rms_displacement_bohr: float
    # Largest matched atomic displacement between the same two cells.
    max_displacement_bohr: float
    # True when spglib was asked for a primitive standardized cell.
    to_primitive: bool
    # True when spglib idealization was applied.
    idealized: bool
@dataclass(frozen=True)
class SymmetriseResult:
    """Output bundle of :func:`symmetrise`.

    Iterating yields ``system`` and then ``report``, so the result unpacks
    like a 2-tuple: ``system, report = symmetrise(...)``.
    """

    # The new standardized periodic structure.
    system: PeriodicSystem
    # Metadata describing how the structure was symmetrised.
    report: SymmetriseReport

    def __iter__(self) -> Iterator[PeriodicSystem | SymmetriseReport]:
        """Yield the system first, then the report."""
        yield from (self.system, self.report)
def detect_spacegroup(
    system: PeriodicSystem | Crystal,
    symprec: float = 1.0e-4,
) -> SpaceGroup:
    """Analyse the space group of ``system`` without modifying it.

    This is the read-only counterpart of ``attach_symmetry(system)``, which
    stores the result on ``system.symmetry`` for downstream calculations.
    Use this helper for introspection and reports.

    Parameters
    ----------
    system:
        A ``PeriodicSystem`` (converted to a ``Crystal`` first) or a
        ``Crystal`` (used directly).
    symprec:
        Symmetry tolerance forwarded to ``analyze_symmetry``.

    Raises
    ------
    TypeError
        If ``system`` is neither a ``PeriodicSystem`` nor a ``Crystal``.
    """
    return analyze_symmetry(_as_crystal(system), symprec=symprec)
def symmetrise(
    system: PeriodicSystem,
    symprec: float = 1.0e-4,
    *,
    to_primitive: bool = False,
    idealize: bool = True,
) -> SymmetriseResult:
    """Standardise (and optionally idealise) a 3D periodic structure via spglib.

    The input ``system`` is never modified; a new :class:`PeriodicSystem` is
    built from spglib's standardized cell and has symmetry attached to it.

    Parameters
    ----------
    system:
        Bulk periodic system; ``system.dim`` must be 3.
    symprec:
        Symmetry tolerance used for detection, standardization and
        ``attach_symmetry``.
    to_primitive:
        Ask spglib for the primitive rather than the conventional
        standardized cell.
    idealize:
        When true, the returned cell is spglib's idealized standardized cell.
        When false, the non-idealized standardized cell is returned and the
        reported displacements are measured against itself.

    Returns
    -------
    SymmetriseResult
        The new system plus a :class:`SymmetriseReport` holding before/after
        space groups, atom counts, cell volumes, and the RMS / maximum matched
        displacement between the non-idealized and returned standardized cells.

    Raises
    ------
    ValueError
        If ``system.dim`` is not 3.
    RuntimeError
        If spglib standardization fails, or ``attach_symmetry`` leaves the
        output's ``symmetry`` unset.
    ImportError
        If the Python ``spglib`` package is unavailable.
    """
    if int(system.dim) != 3:
        raise ValueError(
            "symmetrise currently supports only 3D bulk PeriodicSystem "
            f"objects; got dim={system.dim}"
        )

    group_before = detect_spacegroup(system, symprec=symprec)
    spglib_cell = _system_to_spglib_cell(system)

    # The non-idealized standardized cell is always built: it is the reference
    # the displacement statistics are measured against.
    reference = _standardize_cell(
        spglib_cell,
        to_primitive=to_primitive,
        no_idealize=True,
        symprec=symprec,
    )
    if idealize:
        target = _standardize_cell(
            spglib_cell,
            to_primitive=to_primitive,
            no_idealize=False,
            symprec=symprec,
        )
    else:
        target = reference

    # NOTE(review): charge and multiplicity are copied verbatim even when
    # standardization changes the number of atoms per cell — confirm intended.
    new_system = _spglib_cell_to_system(
        target,
        charge=system.charge,
        multiplicity=system.multiplicity,
    )
    attach_symmetry(new_system, symprec=symprec)
    group_after = new_system.symmetry
    if group_after is None:
        raise RuntimeError("symmetrise: attach_symmetry did not populate output")

    rms_disp, max_disp = _matched_displacement_stats(reference, target)

    def _cell_volume(lattice) -> float:
        # Absolute determinant of the lattice matrix.
        return float(abs(np.linalg.det(np.asarray(lattice))))

    report = SymmetriseReport(
        symprec=float(symprec),
        spacegroup_before=group_before,
        spacegroup_after=group_after,
        n_atoms_before=len(system.unit_cell),
        n_atoms_after=len(new_system.unit_cell),
        volume_before_bohr3=_cell_volume(system.lattice),
        volume_after_bohr3=_cell_volume(new_system.lattice),
        rms_displacement_bohr=float(rms_disp),
        max_displacement_bohr=float(max_disp),
        to_primitive=bool(to_primitive),
        idealized=bool(idealize),
    )
    return SymmetriseResult(system=new_system, report=report)
def _as_crystal(system: PeriodicSystem | Crystal) -> Crystal: if isinstance(system, Crystal): return system if isinstance(system, PeriodicSystem): return _system_to_crystal(system) raise TypeError( "detect_spacegroup expects a PeriodicSystem or Crystal, " f"got {type(system).__name__}" ) def _system_to_crystal(system: PeriodicSystem) -> Crystal: lattice = np.asarray(system.lattice, dtype=float, order="F") inv_lattice = np.linalg.inv(lattice) frac = np.empty((3, len(system.unit_cell)), dtype=float, order="F") species: list[int] = [] for i, atom in enumerate(system.unit_cell): frac[:, i] = inv_lattice @ np.asarray(atom.xyz, dtype=float) species.append(int(atom.Z)) return Crystal(lattice, frac, species) def _system_to_spglib_cell( system: PeriodicSystem, ) -> tuple[np.ndarray, np.ndarray, list[int]]: crystal = _system_to_crystal(system) return _crystal_to_spglib_cell(crystal) def _crystal_to_spglib_cell( crystal: Crystal, ) -> tuple[np.ndarray, np.ndarray, list[int]]: lattice_rows = np.asarray(crystal.lattice, dtype=float).T positions = np.asarray(crystal.fractional_coords, dtype=float).T % 1.0 numbers = [int(z) for z in crystal.species] return lattice_rows, positions, numbers def _standardize_cell( cell: tuple[np.ndarray, np.ndarray, list[int]], *, to_primitive: bool, no_idealize: bool, symprec: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: try: import spglib except ImportError as exc: # pragma: no cover raise ImportError( "symmetrise requires the Python spglib package. " "Install vibe-qc with its runtime dependencies." 
) from exc out = spglib.standardize_cell( cell, to_primitive=bool(to_primitive), no_idealize=bool(no_idealize), symprec=float(symprec), ) if out is None: try: detail = spglib.get_error_message() except TypeError: # pragma: no cover - older spglib signature detail = "" msg = "symmetrise: spglib.standardize_cell failed" if detail: msg += f" ({detail})" raise RuntimeError(msg) lattice_rows, positions, numbers = out return ( np.asarray(lattice_rows, dtype=float), np.asarray(positions, dtype=float) % 1.0, np.asarray(numbers, dtype=int), ) def _spglib_cell_to_system( cell: tuple[np.ndarray, np.ndarray, np.ndarray], *, charge: int, multiplicity: int, ) -> PeriodicSystem: lattice_rows, positions, numbers = cell lattice = np.asarray(lattice_rows, dtype=float).T atoms = [ Atom(int(z), (lattice @ np.asarray(frac, dtype=float)).tolist()) for frac, z in zip(np.asarray(positions, dtype=float), numbers) ] return PeriodicSystem( dim=3, lattice=np.asarray(lattice, dtype=float, order="F"), unit_cell=atoms, charge=int(charge), multiplicity=int(multiplicity), ) def _matched_displacement_stats( raw: tuple[np.ndarray, np.ndarray, np.ndarray], ideal: tuple[np.ndarray, np.ndarray, np.ndarray], ) -> tuple[float, float]: _raw_lattice, raw_pos, raw_numbers = raw ideal_lattice, ideal_pos, ideal_numbers = ideal if len(raw_numbers) != len(ideal_numbers): return float("nan"), float("nan") displacements: list[float] = [] for z in sorted(set(int(x) for x in ideal_numbers)): raw_idx = np.where(raw_numbers == z)[0] ideal_idx = np.where(ideal_numbers == z)[0] if len(raw_idx) != len(ideal_idx): return float("nan"), float("nan") if len(raw_idx) == 0: continue costs = np.empty((len(raw_idx), len(ideal_idx)), dtype=float) for i, iraw in enumerate(raw_idx): for j, jideal in enumerate(ideal_idx): costs[i, j] = _periodic_distance_bohr( raw_pos[iraw], ideal_pos[jideal], ideal_lattice, ) rows, cols = _linear_sum_assignment(costs) displacements.extend(float(costs[i, j]) for i, j in zip(rows, cols)) if not 
displacements: return 0.0, 0.0 arr = np.asarray(displacements, dtype=float) return float(np.sqrt(np.mean(arr * arr))), float(np.max(arr)) def _periodic_distance_bohr( frac_a: np.ndarray, frac_b: np.ndarray, lattice_rows: np.ndarray, ) -> float: delta = np.asarray(frac_b, dtype=float) - np.asarray(frac_a, dtype=float) delta -= np.rint(delta) cart = delta @ np.asarray(lattice_rows, dtype=float) return float(np.linalg.norm(cart)) def _linear_sum_assignment(costs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: try: from scipy.optimize import linear_sum_assignment except ImportError: # pragma: no cover - scipy is a runtime dependency n = costs.shape[0] rows: list[int] = [] cols: list[int] = [] remaining = set(range(n)) for i in range(n): j = min(remaining, key=lambda col: costs[i, col]) rows.append(i) cols.append(j) remaining.remove(j) return np.asarray(rows), np.asarray(cols) return linear_sum_assignment(costs)