"""Source code for ``biomol.core.biomol``, which defines the :class:`BioMol` class."""

# pyright: reportImportCycles=none

from typing import Any, Generic

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self

from biomol.enums import StructureLevel
from biomol.exceptions import IndexMismatchError, StructureLevelError

from .container import FeatureContainer
from .feature import Feature
from .index import IndexTable
from .types import BioMolDict
from .utils import load_bytes, to_bytes
from .view import A_co, AtomView, C_co, ChainView, R_co, ResidueView


[docs] class BioMol(Generic[A_co, R_co, C_co]): """A class representing a biomolecular structure. Parameters ---------- atom_container: FeatureContainer The container holding atom-level features. residue_container: FeatureContainer The container holding residue-level features. chain_container: FeatureContainer The container holding chain-level features. index_table: IndexTable The index table mapping atoms, residues, and chains. metadata: dict[str, Any] | None, optional Additional metadata associated with the biomolecular structure. """ def __init__( self, atom_container: FeatureContainer, residue_container: FeatureContainer, chain_container: FeatureContainer, index_table: IndexTable, metadata: dict[str, Any] | None = None, ) -> None: self._atom_container = atom_container self._residue_container = residue_container self._chain_container = chain_container self._index = index_table self._metadata = metadata or {} self._check_lengths() @property def atoms(self) -> A_co: """View of the atoms in the selection.""" return AtomView(self, np.arange(len(self._atom_container))) # pyright: ignore[reportReturnType] @property def residues(self) -> R_co: """View of the residues in the selection.""" return ResidueView(self, np.arange(len(self._residue_container))) # pyright: ignore[reportReturnType] @property def chains(self) -> C_co: """View of the chains in the selection.""" return ChainView(self, np.arange(len(self._chain_container))) # pyright: ignore[reportReturnType] @property def index_table(self) -> IndexTable: """The index table mapping atoms, residues, and chains.""" return self._index @property def metadata(self) -> dict[str, Any]: """The metadata associated with the biomolecular structure.""" return self._metadata
[docs] def get_container(self, level: StructureLevel) -> FeatureContainer: """Get the feature container for a specific structure level. Parameters ---------- level: StructureLevel The structure level for which to get the feature container. Returns ------- FeatureContainer The feature container for the specified structure level. """ match level: case StructureLevel.ATOM: return self._atom_container case StructureLevel.RESIDUE: return self._residue_container case StructureLevel.CHAIN: return self._chain_container case _: msg = f"Invalid structure level: {level}." raise StructureLevelError(msg)
[docs] def to_dict(self) -> BioMolDict: """Convert the BioMol object to a dictionary.""" return { "atoms": self._atom_container.to_dict(), "residues": self._residue_container.to_dict(), "chains": self._chain_container.to_dict(), "index_table": self._index.to_dict(), "metadata": self._metadata, }
[docs] @classmethod def from_dict(cls, data: BioMolDict) -> Self: """Create a BioMol object from a dictionary. Parameters ---------- data: BioMolDict A dictionary containing the data to create the BioMol object. Returns ------- BioMol The created BioMol object. """ return cls( FeatureContainer.from_dict(data["atoms"]), FeatureContainer.from_dict(data["residues"]), FeatureContainer.from_dict(data["chains"]), IndexTable.from_dict(data["index_table"]), data["metadata"], )
[docs] def to_bytes(self, level: int = 6) -> bytes: """Serialize the container to zstd-compressed bytes. Parameters ---------- level: int, optional The compression level for zstd (default is 6). """ return to_bytes(self.to_dict(), level=level)
[docs] @classmethod def from_bytes(cls, byte_data: bytes) -> Self: """Deserialize the container from zstd-compressed bytes.""" return cls.from_dict(load_bytes(byte_data)) # pyright: ignore[reportArgumentType]
[docs] def update_features( self, level: StructureLevel, **features: Feature | NDArray[Any], ) -> Self: """Update features at a specific structure level. Parameters ---------- level: StructureLevel The structure level at which to update features. **features: Feature | NDArray[Any] Key-value pairs of features to update. Values can be either Feature objects or numpy arrays (which will be converted to NodeFeature). Returns ------- mol Updated BioMol object. Notes ----- Does not modify the current BioMol instance; instead, returns a new one. Examples -------- .. code-block:: python mol = BioMol(...) new_mol = mol.update_features( StructureLevel.ATOM, coord=mol.atoms.coord + 1.0, ) """ containers = { StructureLevel.ATOM: self._atom_container, StructureLevel.RESIDUE: self._residue_container, StructureLevel.CHAIN: self._chain_container, } containers[level] = containers[level].update(**features) return self.__class__( containers[StructureLevel.ATOM], containers[StructureLevel.RESIDUE], containers[StructureLevel.CHAIN], self.index_table, self.metadata, )
[docs] def remove_features(self, level: StructureLevel, *keys: str) -> Self: """Remove features at a specific structure level. Parameters ---------- level: StructureLevel The structure level at which to remove features. *keys: str Keys of the features to remove. Returns ------- mol Updated BioMol object. Notes ----- Does not modify the current BioMol instance; instead, returns a new one. Examples -------- .. code-block:: python mol = BioMol(...) new_mol = mol.remove_features(StructureLevel.ATOM, "coord", "element") """ containers = { StructureLevel.ATOM: self._atom_container, StructureLevel.RESIDUE: self._residue_container, StructureLevel.CHAIN: self._chain_container, } containers[level] = containers[level].remove(*keys) return self.__class__( containers[StructureLevel.ATOM], containers[StructureLevel.RESIDUE], containers[StructureLevel.CHAIN], self.index_table, self.metadata, )
[docs] @classmethod def concat(cls, mols: list[Self]) -> Self: """Concatenate multiple BioMol objects. Parameters ---------- mols: list[Self] List of BioMol objects to concatenate. Returns ------- Self Concatenated BioMol object. Notes ----- All containers must have the same set of feature keys. Metadata from the first BioMol object is retained. Always returns a new BioMol instance, even if only one object is provided. Examples -------- .. code-block:: python mol1 = BioMol(...) mol2 = BioMol(...) concatenated_mol = BioMol.concat([mol1, mol2]) """ if not mols: msg = "Cannot concatenate an empty list of BioMol objects." raise ValueError(msg) atom_containers = [mol.get_container(StructureLevel.ATOM) for mol in mols] residue_containers = [mol.get_container(StructureLevel.RESIDUE) for mol in mols] chain_containers = [mol.get_container(StructureLevel.CHAIN) for mol in mols] residue_counts = [len(container) for container in residue_containers] residue_offsets = np.cumsum([0, *residue_counts[:-1]]) atom_to_res = [ mol.index_table.atom_to_res + offset for mol, offset in zip(mols, residue_offsets, strict=True) ] chain_counts = [len(container) for container in chain_containers] chain_offsets = np.cumsum([0, *chain_counts[:-1]]) res_to_chain = [ mol.index_table.res_to_chain + offset for mol, offset in zip(mols, chain_offsets, strict=True) ] concatenated_table = IndexTable.from_parents( atom_to_res=np.concatenate(atom_to_res, axis=0), res_to_chain=np.concatenate(res_to_chain, axis=0), n_chain=sum(chain_counts), ) return cls( FeatureContainer.concat(atom_containers), FeatureContainer.concat(residue_containers), FeatureContainer.concat(chain_containers), concatenated_table, metadata=mols[0].metadata.copy(), )
[docs] def copy(self) -> Self: """Create a deep copy of the BioMol.""" return self.__class__( self._atom_container.copy(), self._residue_container.copy(), self._chain_container.copy(), self._index.copy(), self._metadata.copy(), )
def _check_lengths(self) -> None: """Check if the lengths of the containers and index table are consistent.""" if len(self._atom_container) != len(self._index.atom_to_res): msg = ( "Atom length mismatch: " f"atom_container has length {len(self._atom_container)}, " f"but index table has length {len(self._index.atom_to_res)}." ) raise IndexMismatchError(msg) if len(self._residue_container) != len(self._index.res_to_chain): msg = ( "Residue length mismatch: " f"residue_container has length {len(self._residue_container)}, " f"but index table has length {len(self._index.res_to_chain)}." ) raise IndexMismatchError(msg) if len(self._chain_container) != len(self._index.chain_res_indptr) - 1: msg = ( "Chain length mismatch: " f"chain_container has length {len(self._chain_container)}, " f"but index table has length {len(self._index.chain_res_indptr) - 1}." ) raise IndexMismatchError(msg) def __repr__(self) -> str: """Return a string representation of the BioMol object.""" return ( f"<{self.__class__.__name__} with {len(self._atom_container)} atoms, " f"{len(self._residue_container)} residues, " f"and {len(self._chain_container)} chains>" ) def __add__(self, other: Self) -> Self: """Concatenate two BioMol objects using the + operator. Note ---- For concatenating more than two objects, use BioMol.concat([mol1, mol2, mol3]) for better performance. """ return self.concat([self, other])