Source code for dynsight._internal.lens.lens

"""Compute LENS for each atom in a trajectory."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from MDAnalysis import AtomGroup, Universe
    from numpy.typing import NDArray


import numba
import numpy as np
from numba import njit, prange


@njit(cache=True, fastmath=True)  # type: ignore[untyped-decorator]
def _pbc_diff(
    dx: NDArray[np.float64],
    box: NDArray[np.float64],
) -> NDArray[np.float64]:
    """Find distances in PBC box."""
    for k in range(3):
        if box[k] > 0.0:
            dx[k] -= np.rint(dx[k] / box[k]) * box[k]
    return dx


@njit(cache=True, fastmath=True)  # type: ignore[untyped-decorator]
def build_cell_list(
    positions: NDArray[np.float64],
    box: NDArray[np.float64],
    cell_size: float,
) -> tuple[
    NDArray[np.int32],
    NDArray[np.int32],
    NDArray[np.int32],
    NDArray[np.int32],
]:
    """Build a 3D periodic cell list."""
    n_atoms = positions.shape[0]
    min_cells = 3
    ncellx = max(min_cells, int(box[0] // cell_size))
    ncelly = max(min_cells, int(box[1] // cell_size))
    ncellz = max(min_cells, int(box[2] // cell_size))
    n_cells = ncellx * ncelly * ncellz

    head = np.full(n_cells, -1, dtype=np.int32)
    next_ = np.full(n_atoms, -1, dtype=np.int32)
    cell_ids = np.empty((n_atoms, 3), dtype=np.int32)

    for i in range(n_atoms):
        cx = int(positions[i, 0] / box[0] * ncellx) % ncellx
        cy = int(positions[i, 1] / box[1] * ncelly) % ncelly
        cz = int(positions[i, 2] / box[2] * ncellz) % ncellz
        cell_ids[i, 0] = cx
        cell_ids[i, 1] = cy
        cell_ids[i, 2] = cz
        cindex = cx * ncelly * ncellz + cy * ncellz + cz
        next_[i] = head[cindex]
        head[cindex] = i
    n_cell = np.array([ncellx, ncelly, ncellz])

    return cell_ids, head, next_, n_cell


# We need a function this complex and deep for numba to work
# This is why we are ignoring ruff complaints C901, PLR0912
@njit(cache=True, fastmath=True, parallel=True)  # type: ignore[untyped-decorator]
def neighbor_list_celllist_centers(  # noqa: C901, PLR0912
    positions_env: NDArray[np.float64],
    positions_cent: NDArray[np.float64],
    r_cut: float,
    box: NDArray[np.float64],
    respect_pbc: bool,
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
    """Build a CSR neighbor list *only for the centers*."""
    n_cent = positions_cent.shape[0]
    r_cut2 = (r_cut - 1e-6) ** 2

    _, head, next_, n_cell = build_cell_list(positions_env, box, r_cut)
    nx, ny, nz = n_cell

    n_neigh = np.zeros(n_cent, dtype=np.int32)
    # ---- count the neighbors for each center ----
    for i in prange(n_cent):
        cx = int(positions_cent[i, 0] / box[0] * nx) % nx
        cy = int(positions_cent[i, 1] / box[1] * ny) % ny
        cz = int(positions_cent[i, 2] / box[2] * nz) % nz

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    nx_ = (cx + dx) % nx
                    ny_ = (cy + dy) % ny
                    nz_ = (cz + dz) % nz
                    cidx = nx_ * ny * nz + ny_ * nz + nz_
                    j = head[cidx]
                    while j != -1:
                        dr = positions_env[j] - positions_cent[i]
                        if respect_pbc:
                            dr = _pbc_diff(dr, box)
                        dr2 = dr[0] ** 2 + dr[1] ** 2 + dr[2] ** 2
                        if dr2 > 0.0 and dr2 < r_cut2:
                            n_neigh[i] += 1
                        j = next_[j]

    indptr = np.empty(n_cent + 1, dtype=np.int32)
    indptr[0] = 0
    for i in range(n_cent):
        indptr[i + 1] = indptr[i] + n_neigh[i]

    indices = np.empty(indptr[-1], dtype=np.int32)
    cursor = np.zeros(n_cent, dtype=np.int32)

    # ---- fill up neighbors' lists ----
    for i in prange(n_cent):
        cx = int(positions_cent[i, 0] / box[0] * nx) % nx
        cy = int(positions_cent[i, 1] / box[1] * ny) % ny
        cz = int(positions_cent[i, 2] / box[2] * nz) % nz

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    nx_ = (cx + dx) % nx
                    ny_ = (cy + dy) % ny
                    nz_ = (cz + dz) % nz
                    cidx = nx_ * ny * nz + ny_ * nz + nz_
                    j = head[cidx]
                    while j != -1:
                        dr = positions_env[j] - positions_cent[i]
                        if respect_pbc:
                            dr = _pbc_diff(dr, box)
                        dr2 = dr[0] ** 2 + dr[1] ** 2 + dr[2] ** 2
                        if dr2 > 0.0 and dr2 < r_cut2:
                            pos_i = indptr[i] + cursor[i]
                            indices[pos_i] = j
                            cursor[i] += 1
                        j = next_[j]

    # sort each list (needed for computing intersections)
    for i in range(n_cent):
        s, e = indptr[i], indptr[i + 1]
        if e > s + 1:
            indices[s:e].sort()

    return indptr, indices


@njit(cache=True, fastmath=True)  # type: ignore[untyped-decorator]
def lens_from_two_csr(
    indptr1: NDArray[np.int32],
    indices1: NDArray[np.int32],
    indptr2: NDArray[np.int32],
    indices2: NDArray[np.int32],
) -> NDArray[np.float64]:
    """Return LENS distance between two neighbor lists.

    Note: CSR lists do NOT include the particle istelf.
    """
    n_centers = len(indptr1) - 1
    out = np.zeros(n_centers, dtype=np.float64)
    for u in range(n_centers):
        s1, e1 = indptr1[u], indptr1[u + 1]
        s2, e2 = indptr2[u], indptr2[u + 1]
        a = e1 - s1
        b = e2 - s2
        denom = a + b  # -2
        if denom <= 0:
            out[u] = 0.0
            continue

        i, j = s1, s2
        inter = 0
        while i < e1 and j < e2:
            vi = indices1[i]
            vj = indices2[j]
            if vi == vj:
                inter += 1
                i += 1
                j += 1
            elif vi < vj:
                i += 1
            else:
                j += 1

        numer = (a + b) - 2 * inter
        out[u] = numer / denom
    return out


[docs] def compute_lens( universe: Universe, r_cut: float, delay: int = 1, centers: str = "all", selection: str = "all", trajslice: slice | None = None, respect_pbc: bool = True, n_jobs: int = 1, ) -> NDArray[np.float64]: r"""Compute the LENS descriptor for all frames along a trajectory. LENS was developed by Martina Crippa, see https://doi.org/10.1073/pnas.2300565120. The current implementation is mainly due to @SimoneMartino98. .. warning:: The LENS functions only work with orthogonal simulation boxes. We are working to make them compatible with non-orthogonal ones. The LENS value of a particle between two frames is deined as: .. math:: LENS(t, t + \delta t) = \frac{|C(t)\cup C(t+\delta t)| - |C(t)\cap C(t+\delta t)|} {|C(t)| + |C(t+\delta t)|} where :math:`C(t)` and :math:`C(t+\delta t)` are the neighbors' list of the particle at frames :math:`t` and :math:`t+\delta t` respectively. Parameters: universe : MDAnalysis Universe containing the trajectory. r_cut : r_cut distance (Å) for defining neighbors. delay : Number of frames separating the pairs for comparison. centers : Atom selection string for the centers where LENS is computed. selection : Atom selection string defining the environment. trajslice : Frame slicing parameters for trajectory iteration. respect_pbc : Whether to apply periodic boundary conditions. n_jobs : The number of jobs for parallelization with numba. Returns: LENS values for each center and each pair of frames. Has shape (n_centers, n_pairs) """ numba.set_num_threads(n_jobs) ag_env = universe.select_atoms(selection) ag_cent = universe.select_atoms(centers) if trajslice is not None: fr_idx = list(range(universe.trajectory.n_frames))[trajslice] else: fr_idx = list(range(universe.trajectory.n_frames)) pairs = [ (fr_idx[i], fr_idx[i + delay]) for i in range(len(fr_idx) - delay) ] if not pairs: msg = "No valid pairs found." raise RuntimeError(msg) lens_array = np.zeros((ag_cent.n_atoms, len(pairs)), dtype=np.float64) for k, (t1, t2) in enumerate(pairs): # ---- frame t1 ---- universe.trajectory[t1] pos_env1 = ag_env.positions.astype(np.float64) pos_cent1 = ag_cent.positions.astype(np.float64) if universe.trajectory.ts.dimensions is not None: box = universe.trajectory.ts.dimensions[:3] else: coords = universe.atoms.positions mins = coords.min(axis=0) maxs = coords.max(axis=0) box = (maxs - mins) * 1.01 indptr_t1, indices_t1 = neighbor_list_celllist_centers( positions_env=pos_env1, positions_cent=pos_cent1, r_cut=r_cut, box=box, respect_pbc=respect_pbc, ) # ---- frame t2 ---- universe.trajectory[t2] pos_env2 = ag_env.positions.astype(np.float64) pos_cent2 = ag_cent.positions.astype(np.float64) if universe.trajectory.ts.dimensions is not None: box = universe.trajectory.ts.dimensions[:3] else: coords = universe.atoms.positions mins = coords.min(axis=0) maxs = coords.max(axis=0) box = (maxs - mins) * 1.01 indptr_t2, indices_t2 = neighbor_list_celllist_centers( positions_env=pos_env2, positions_cent=pos_cent2, r_cut=r_cut, box=box, respect_pbc=respect_pbc, ) # ---- LENS ---- lens_array[:, k] = lens_from_two_csr( indptr1=indptr_t1, indices1=indices_t1, indptr2=indptr_t2, indices2=indices_t2, ) return lens_array
[docs] def list_neighbours_along_trajectory( universe: Universe, r_cut: float, centers: str = "all", selection: str = "all", trajslice: slice | None = None, respect_pbc: bool = True, n_jobs: int = 1, ) -> list[list[AtomGroup]]: """Produce a per-frame list of neighbors. .. warning:: The LENS functions only work with orthogonal simulation boxes. We are working to make them compatible with non-orthogonal ones. Parameters: universe : MDAnalysis Universe containing the trajectory. r_cut : r_cut distance (Å) for defining neighbors. centers : Atom selection string for the centers where LENS is computed. selection : Atom selection string defining the environment. trajslice : Frame slicing parameters for trajectory iteration. respect_pbc : Whether to apply periodic boundary conditions. n_jobs : The number of jobs for parallelization with numba. Returns: List of frames, each frame a list of AtomGroups for each atom. """ if trajslice is None: trajslice = slice(None) ag_centers = universe.select_atoms(centers) ag_env = universe.select_atoms(selection) frame_indices = list( range(*trajslice.indices(universe.trajectory.n_frames)) ) numba.set_num_threads(n_jobs) def _compute_frame_neighbors(frame_idx: int) -> list[AtomGroup]: universe.trajectory[frame_idx] pos_env = ag_env.positions.astype(np.float64) pos_cent = ag_centers.positions.astype(np.float64) if universe.trajectory.ts.dimensions is not None: box = universe.trajectory.ts.dimensions[:3].astype(np.float64) local_pbc = respect_pbc else: # No periodic box: shift positions to [0, span] and disable PBC all_pos = np.vstack([pos_env, pos_cent]) mins = all_pos.min(axis=0) maxs = all_pos.max(axis=0) box = (maxs - mins) + 2 * r_cut pos_env = pos_env - mins pos_cent = pos_cent - mins local_pbc = False # --- Build neighbor list (CSR form) --- indptr, indices = neighbor_list_celllist_centers( positions_env=pos_env, positions_cent=pos_cent, r_cut=r_cut, box=box, respect_pbc=local_pbc, ) # --- Reconstruct AtomGroups per center --- frame_neighbors: list[AtomGroup] = [] for i in range(ag_centers.n_atoms): start, end = indptr[i], indptr[i + 1] neighbor_atoms = ag_env[indices[start:end]] frame_neighbors.append(neighbor_atoms) return frame_neighbors return [_compute_frame_neighbors(frame) for frame in frame_indices]