# pyright: reportImportCycles=none
from typing import Any, Generic
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from biomol.enums import StructureLevel
from biomol.exceptions import IndexMismatchError, StructureLevelError
from .container import FeatureContainer
from .feature import Feature
from .index import IndexTable
from .types import BioMolDict
from .utils import load_bytes, to_bytes
from .view import A_co, AtomView, C_co, ChainView, R_co, ResidueView
[docs]
class BioMol(Generic[A_co, R_co, C_co]):
"""A class representing a biomolecular structure.
Parameters
----------
atom_container: FeatureContainer
The container holding atom-level features.
residue_container: FeatureContainer
The container holding residue-level features.
chain_container: FeatureContainer
The container holding chain-level features.
index_table: IndexTable
The index table mapping atoms, residues, and chains.
metadata: dict[str, Any] | None, optional
Additional metadata associated with the biomolecular structure.
"""
def __init__(
self,
atom_container: FeatureContainer,
residue_container: FeatureContainer,
chain_container: FeatureContainer,
index_table: IndexTable,
metadata: dict[str, Any] | None = None,
) -> None:
self._atom_container = atom_container
self._residue_container = residue_container
self._chain_container = chain_container
self._index = index_table
self._metadata = metadata or {}
self._check_lengths()
@property
def atoms(self) -> A_co:
"""View of the atoms in the selection."""
return AtomView(self, np.arange(len(self._atom_container))) # pyright: ignore[reportReturnType]
@property
def residues(self) -> R_co:
"""View of the residues in the selection."""
return ResidueView(self, np.arange(len(self._residue_container))) # pyright: ignore[reportReturnType]
@property
def chains(self) -> C_co:
"""View of the chains in the selection."""
return ChainView(self, np.arange(len(self._chain_container))) # pyright: ignore[reportReturnType]
@property
def index_table(self) -> IndexTable:
"""The index table mapping atoms, residues, and chains."""
return self._index
@property
def metadata(self) -> dict[str, Any]:
"""The metadata associated with the biomolecular structure."""
return self._metadata
[docs]
def get_container(self, level: StructureLevel) -> FeatureContainer:
"""Get the feature container for a specific structure level.
Parameters
----------
level: StructureLevel
The structure level for which to get the feature container.
Returns
-------
FeatureContainer
The feature container for the specified structure level.
"""
match level:
case StructureLevel.ATOM:
return self._atom_container
case StructureLevel.RESIDUE:
return self._residue_container
case StructureLevel.CHAIN:
return self._chain_container
case _:
msg = f"Invalid structure level: {level}."
raise StructureLevelError(msg)
[docs]
def to_dict(self) -> BioMolDict:
"""Convert the BioMol object to a dictionary."""
return {
"atoms": self._atom_container.to_dict(),
"residues": self._residue_container.to_dict(),
"chains": self._chain_container.to_dict(),
"index_table": self._index.to_dict(),
"metadata": self._metadata,
}
[docs]
@classmethod
def from_dict(cls, data: BioMolDict) -> Self:
"""Create a BioMol object from a dictionary.
Parameters
----------
data: BioMolDict
A dictionary containing the data to create the BioMol object.
Returns
-------
BioMol
The created BioMol object.
"""
return cls(
FeatureContainer.from_dict(data["atoms"]),
FeatureContainer.from_dict(data["residues"]),
FeatureContainer.from_dict(data["chains"]),
IndexTable.from_dict(data["index_table"]),
data["metadata"],
)
[docs]
def to_bytes(self, level: int = 6) -> bytes:
"""Serialize the container to zstd-compressed bytes.
Parameters
----------
level: int, optional
The compression level for zstd (default is 6).
"""
return to_bytes(self.to_dict(), level=level)
[docs]
@classmethod
def from_bytes(cls, byte_data: bytes) -> Self:
"""Deserialize the container from zstd-compressed bytes."""
return cls.from_dict(load_bytes(byte_data)) # pyright: ignore[reportArgumentType]
[docs]
def update_features(
self,
level: StructureLevel,
**features: Feature | NDArray[Any],
) -> Self:
"""Update features at a specific structure level.
Parameters
----------
level: StructureLevel
The structure level at which to update features.
**features: Feature | NDArray[Any]
Key-value pairs of features to update. Values can be either Feature
objects or numpy arrays (which will be converted to NodeFeature).
Returns
-------
mol
Updated BioMol object.
Notes
-----
Does not modify the current BioMol instance; instead, returns a new one.
Examples
--------
.. code-block:: python
mol = BioMol(...)
new_mol = mol.update_features(
StructureLevel.ATOM,
coord=mol.atoms.coord + 1.0,
)
"""
containers = {
StructureLevel.ATOM: self._atom_container,
StructureLevel.RESIDUE: self._residue_container,
StructureLevel.CHAIN: self._chain_container,
}
containers[level] = containers[level].update(**features)
return self.__class__(
containers[StructureLevel.ATOM],
containers[StructureLevel.RESIDUE],
containers[StructureLevel.CHAIN],
self.index_table,
self.metadata,
)
[docs]
def remove_features(self, level: StructureLevel, *keys: str) -> Self:
"""Remove features at a specific structure level.
Parameters
----------
level: StructureLevel
The structure level at which to remove features.
*keys: str
Keys of the features to remove.
Returns
-------
mol
Updated BioMol object.
Notes
-----
Does not modify the current BioMol instance; instead, returns a new one.
Examples
--------
.. code-block:: python
mol = BioMol(...)
new_mol = mol.remove_features(StructureLevel.ATOM, "coord", "element")
"""
containers = {
StructureLevel.ATOM: self._atom_container,
StructureLevel.RESIDUE: self._residue_container,
StructureLevel.CHAIN: self._chain_container,
}
containers[level] = containers[level].remove(*keys)
return self.__class__(
containers[StructureLevel.ATOM],
containers[StructureLevel.RESIDUE],
containers[StructureLevel.CHAIN],
self.index_table,
self.metadata,
)
[docs]
@classmethod
def concat(cls, mols: list[Self]) -> Self:
"""Concatenate multiple BioMol objects.
Parameters
----------
mols: list[Self]
List of BioMol objects to concatenate.
Returns
-------
Self
Concatenated BioMol object.
Notes
-----
All containers must have the same set of feature keys.
Metadata from the first BioMol object is retained.
Always returns a new BioMol instance, even if only one object is provided.
Examples
--------
.. code-block:: python
mol1 = BioMol(...)
mol2 = BioMol(...)
concatenated_mol = BioMol.concat([mol1, mol2])
"""
if not mols:
msg = "Cannot concatenate an empty list of BioMol objects."
raise ValueError(msg)
atom_containers = [mol.get_container(StructureLevel.ATOM) for mol in mols]
residue_containers = [mol.get_container(StructureLevel.RESIDUE) for mol in mols]
chain_containers = [mol.get_container(StructureLevel.CHAIN) for mol in mols]
residue_counts = [len(container) for container in residue_containers]
residue_offsets = np.cumsum([0, *residue_counts[:-1]])
atom_to_res = [
mol.index_table.atom_to_res + offset
for mol, offset in zip(mols, residue_offsets, strict=True)
]
chain_counts = [len(container) for container in chain_containers]
chain_offsets = np.cumsum([0, *chain_counts[:-1]])
res_to_chain = [
mol.index_table.res_to_chain + offset
for mol, offset in zip(mols, chain_offsets, strict=True)
]
concatenated_table = IndexTable.from_parents(
atom_to_res=np.concatenate(atom_to_res, axis=0),
res_to_chain=np.concatenate(res_to_chain, axis=0),
n_chain=sum(chain_counts),
)
return cls(
FeatureContainer.concat(atom_containers),
FeatureContainer.concat(residue_containers),
FeatureContainer.concat(chain_containers),
concatenated_table,
metadata=mols[0].metadata.copy(),
)
[docs]
def copy(self) -> Self:
"""Create a deep copy of the BioMol."""
return self.__class__(
self._atom_container.copy(),
self._residue_container.copy(),
self._chain_container.copy(),
self._index.copy(),
self._metadata.copy(),
)
def _check_lengths(self) -> None:
"""Check if the lengths of the containers and index table are consistent."""
if len(self._atom_container) != len(self._index.atom_to_res):
msg = (
"Atom length mismatch: "
f"atom_container has length {len(self._atom_container)}, "
f"but index table has length {len(self._index.atom_to_res)}."
)
raise IndexMismatchError(msg)
if len(self._residue_container) != len(self._index.res_to_chain):
msg = (
"Residue length mismatch: "
f"residue_container has length {len(self._residue_container)}, "
f"but index table has length {len(self._index.res_to_chain)}."
)
raise IndexMismatchError(msg)
if len(self._chain_container) != len(self._index.chain_res_indptr) - 1:
msg = (
"Chain length mismatch: "
f"chain_container has length {len(self._chain_container)}, "
f"but index table has length {len(self._index.chain_res_indptr) - 1}."
)
raise IndexMismatchError(msg)
def __repr__(self) -> str:
"""Return a string representation of the BioMol object."""
return (
f"<{self.__class__.__name__} with {len(self._atom_container)} atoms, "
f"{len(self._residue_container)} residues, "
f"and {len(self._chain_container)} chains>"
)
def __add__(self, other: Self) -> Self:
"""Concatenate two BioMol objects using the + operator.
Note
----
For concatenating more than two objects, use BioMol.concat([mol1, mol2, mol3])
for better performance.
"""
return self.concat([self, other])