Source code for biomol.core.index

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar

import numpy as np

from biomol.enums import StructureLevel
from biomol.exceptions import (
    IndexInvalidError,
    IndexOutOfBoundsError,
    StructureLevelError,
)

if TYPE_CHECKING:
    from numpy.typing import NDArray

    from .types import IndexTableDict


def _build_csr(
    parent_of_child: NDArray[np.integer],
    n_parent: int,
) -> tuple[NDArray[np.integer], NDArray[np.integer]]:
    """Build a CSR (compressed sparse row) parent -> children mapping.

    Parameters
    ----------
    parent_of_child : NDArray[np.integer]
        1D array where element ``i`` is the parent index of child ``i``.
        Values must be non-negative and smaller than ``n_parent``.
    n_parent : int
        Total number of parents, including parents that have no children.

    Returns
    -------
    indptr : NDArray[np.integer]
        Array of length ``n_parent + 1``. The children of parent ``p`` are
        ``indices[indptr[p]:indptr[p + 1]]``.
    indices : NDArray[np.integer]
        Child indices grouped by parent, in ascending child order within
        each parent. Has the same dtype as ``parent_of_child``.
    """
    parent_of_child = np.asarray(parent_of_child)

    # Number of children per parent; minlength keeps childless parents.
    counts = np.bincount(parent_of_child, minlength=n_parent)

    # indptr is the exclusive prefix sum of counts with a leading zero.
    indptr = np.empty(n_parent + 1, dtype=int)
    indptr[0] = 0
    np.cumsum(counts, dtype=int, out=indptr[1:])

    # A stable sort by parent groups the children of each parent together
    # while preserving their original (ascending) order inside each group,
    # which is exactly the CSR ordering, without a Python-level loop.
    order = np.argsort(parent_of_child, kind="stable")
    indices = order.astype(parent_of_child.dtype, copy=False)
    return indptr, indices


@dataclass(frozen=True, slots=True)
class IndexTable:
    """Index mapping between structural levels.

    This class stores forward parent mappings and reverse CSR mappings to
    efficiently move between atoms, residues, and chains.

    Parameters
    ----------
    atom_to_res : NDArray[np.integer]
        1D array mapping each atom index to its parent residue index.
    res_to_chain : NDArray[np.integer]
        1D array mapping each residue index to its parent chain index.
    res_atom_indptr : NDArray[np.integer]
        CSR index pointer array for residues to atoms mapping.
    res_atom_indices : NDArray[np.integer]
        CSR indices array for residues to atoms mapping.
    chain_res_indptr : NDArray[np.integer]
        CSR index pointer array for chains to residues mapping.
    chain_res_indices : NDArray[np.integer]
        CSR indices array for chains to residues mapping.

    Notes
    -----
    The mapping methods do not range-check the indices they are given;
    NumPy indexing semantics apply.

    Examples
    --------
    .. code-block:: python

        >>> table = IndexTable.from_parents(
        ...     atom_to_res=np.array([0, 0, 1, 1, 2]),
        ...     res_to_chain=np.array([0, 0, 1]),
        ... )
        >>> table.atoms_to_residues(np.array([0, 2, 4]))
        array([0, 1, 2])
        >>> table.residues_to_chains(np.array([0, 2]))
        array([0, 1])
        >>> table.chains_to_residues(np.array([0, 1]))
        array([0, 1, 2])
    """

    atom_to_res: NDArray[np.integer]
    """1D array mapping each atom index to its parent residue index."""

    res_to_chain: NDArray[np.integer]
    """1D array mapping each residue index to its parent chain index."""

    res_atom_indptr: NDArray[np.integer]
    """CSR index pointer array for residues to atoms mapping."""

    res_atom_indices: NDArray[np.integer]
    """CSR indices array for residues to atoms mapping."""

    chain_res_indptr: NDArray[np.integer]
    """CSR index pointer array for chains to residues mapping."""

    chain_res_indices: NDArray[np.integer]
    """CSR indices array for chains to residues mapping."""

    # Dispatch table used by ``convert``: (source, target) -> method name.
    _converter_table: ClassVar[dict[tuple[StructureLevel, StructureLevel], str]] = {
        (StructureLevel.ATOM, StructureLevel.RESIDUE): "atoms_to_residues",
        (StructureLevel.ATOM, StructureLevel.CHAIN): "atoms_to_chains",
        (StructureLevel.RESIDUE, StructureLevel.ATOM): "residues_to_atoms",
        (StructureLevel.RESIDUE, StructureLevel.CHAIN): "residues_to_chains",
        (StructureLevel.CHAIN, StructureLevel.RESIDUE): "chains_to_residues",
        (StructureLevel.CHAIN, StructureLevel.ATOM): "chains_to_atoms",
    }

    @classmethod
    def from_parents(
        cls,
        atom_to_res: NDArray[np.integer],
        res_to_chain: NDArray[np.integer],
        n_chain: int | None = None,
    ) -> IndexTable:
        """Create IndexTable from forward parent mappings.

        Parameters
        ----------
        atom_to_res : NDArray[np.integer]
            1D array mapping each atom index to its parent residue index.
        res_to_chain : NDArray[np.integer]
            1D array mapping each residue index to its parent chain index.
        n_chain : int | None, optional
            Total number of chains. If None, inferred as max(res_to_chain) + 1.

        Returns
        -------
        IndexTable
            Table holding the given forward mappings (stored as passed, not
            copied) and the reverse CSR mappings built from them.

        Raises
        ------
        IndexInvalidError
            If either mapping is not 1D, is empty, or has negative values.
        IndexOutOfBoundsError
            If ``atom_to_res`` references a residue that does not exist, if
            ``n_chain`` is not positive, or if ``res_to_chain`` references a
            chain index that is not smaller than ``n_chain``.
        """
        cls._check_indices(atom_to_res)
        cls._check_indices(res_to_chain)

        # The number of residues is fixed by the residue -> chain mapping.
        n_residue = len(res_to_chain)
        if atom_to_res.max() >= n_residue:
            msg = (
                f"atom_to_res has out-of-range values={atom_to_res.max()} "
                f"for {n_residue=}"
            )
            raise IndexOutOfBoundsError(msg)

        if n_chain is None:
            n_chain = int(res_to_chain.max()) + 1
        elif n_chain <= 0:
            msg = f"n_chain must be positive, got {n_chain}"
            raise IndexOutOfBoundsError(msg)
        elif res_to_chain.max() >= n_chain:
            msg = (
                f"res_to_chain has out-of-range values={res_to_chain.max()} "
                f"for {n_chain=}"
            )
            raise IndexOutOfBoundsError(msg)

        res_atom_indptr, res_atom_indices = _build_csr(atom_to_res, n_residue)
        chain_res_indptr, chain_res_indices = _build_csr(res_to_chain, n_chain)
        return cls(
            atom_to_res=atom_to_res,
            res_to_chain=res_to_chain,
            res_atom_indptr=res_atom_indptr,
            res_atom_indices=res_atom_indices,
            chain_res_indptr=chain_res_indptr,
            chain_res_indices=chain_res_indices,
        )

    @staticmethod
    def _gather_children(
        indptr: NDArray[np.integer],
        children: NDArray[np.integer],
        parents: NDArray[np.integer],
    ) -> NDArray[np.integer]:
        """Concatenate the CSR child slices of ``parents`` in the given order."""
        parents = np.asarray(parents)
        if parents.size == 0:
            # Always return an integer array of the children's dtype, so the
            # empty result matches the dtype of the non-empty path (an empty
            # list would otherwise surface as a float64 array).
            return np.empty((0,), dtype=children.dtype)
        return np.concatenate(
            [children[indptr[parent] : indptr[parent + 1]] for parent in parents],
        )

    def atoms_to_residues(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map atom indices to residue indices."""
        return self.atom_to_res[np.asarray(indices)]

    def residues_to_chains(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map residue indices to chain indices."""
        return self.res_to_chain[np.asarray(indices)]

    def atoms_to_chains(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map atom indices to chain indices."""
        res_indices = self.atoms_to_residues(indices)
        return self.residues_to_chains(res_indices)

    def residues_to_atoms(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map residue indices to concatenated atom indices.

        The atoms of each requested residue are emitted in request order.
        Empty input yields an empty array with the dtype of
        ``res_atom_indices``.
        """
        return self._gather_children(
            self.res_atom_indptr,
            self.res_atom_indices,
            indices,
        )

    def chains_to_residues(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map chain indices to concatenated residue indices.

        The residues of each requested chain are emitted in request order.
        Empty input yields an empty array with the dtype of
        ``chain_res_indices``.
        """
        return self._gather_children(
            self.chain_res_indptr,
            self.chain_res_indices,
            indices,
        )

    def chains_to_atoms(self, indices: NDArray[np.integer]) -> NDArray[np.integer]:
        """Map chain indices to concatenated atom indices."""
        res_indices = self.chains_to_residues(indices)
        return self.residues_to_atoms(res_indices)

    def convert(
        self,
        indices: NDArray[np.integer],
        source: StructureLevel,
        target: StructureLevel,
    ) -> NDArray[np.integer]:
        """Convert indices between structural levels.

        Parameters
        ----------
        indices : NDArray[np.integer]
            1D array of indices at the source level.
        source : StructureLevel
            The structural level of the input indices.
        target : StructureLevel
            The structural level to convert the indices to.

        Returns
        -------
        NDArray[np.integer]
            1D array of indices at the target level. When ``source`` equals
            ``target`` the input is returned unchanged.

        Raises
        ------
        StructureLevelError
            If no conversion is registered for ``(source, target)``.
        """
        if source == target:
            return indices
        method_name = self._converter_table.get((source, target))
        if method_name is None:
            msg = f"Invalid level conversion: {source} -> {target}"
            raise StructureLevelError(msg)
        return getattr(self, method_name)(indices)

    @staticmethod
    def _check_indices(indices: NDArray[np.integer]) -> None:
        """Validate that ``indices`` is a non-empty, non-negative 1D array.

        Raises
        ------
        IndexInvalidError
            If the array is not 1D, is empty, or contains negative values.
        """
        if indices.ndim != 1:
            msg = f"Indices must be 1D, got shape {indices.shape}"
            raise IndexInvalidError(msg)
        if indices.size == 0:
            msg = "Indices must be non-empty"
            raise IndexInvalidError(msg)
        if np.any(indices < 0):
            msg = "Indices contain negative values"
            raise IndexInvalidError(msg)

    def copy(self) -> IndexTable:
        """Create a deep copy of the IndexTable."""
        return IndexTable(
            self.atom_to_res.copy(),
            self.res_to_chain.copy(),
            self.res_atom_indptr.copy(),
            self.res_atom_indices.copy(),
            self.chain_res_indptr.copy(),
            self.chain_res_indices.copy(),
        )

    def to_dict(self) -> IndexTableDict:
        """Convert IndexTable to a JSON-serializable dictionary."""
        return {
            "atom_to_res": self.atom_to_res.tolist(),
            "res_to_chain": self.res_to_chain.tolist(),
            "res_atom_indptr": self.res_atom_indptr.tolist(),
            "res_atom_indices": self.res_atom_indices.tolist(),
            "chain_res_indptr": self.chain_res_indptr.tolist(),
            "chain_res_indices": self.chain_res_indices.tolist(),
        }

    @classmethod
    def from_dict(cls, data: IndexTableDict) -> IndexTable:
        """Create IndexTable from a dictionary.

        The six arrays are rebuilt from the lists in ``data`` with integer
        dtype; the stored CSR arrays are used as given, not recomputed.
        """
        return cls(
            atom_to_res=np.array(data["atom_to_res"], dtype=int),
            res_to_chain=np.array(data["res_to_chain"], dtype=int),
            res_atom_indptr=np.array(data["res_atom_indptr"], dtype=int),
            res_atom_indices=np.array(data["res_atom_indices"], dtype=int),
            chain_res_indptr=np.array(data["chain_res_indptr"], dtype=int),
            chain_res_indices=np.array(data["chain_res_indices"], dtype=int),
        )