from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from biomol.exceptions import FeatureKeyError, IndexMismatchError, IndexOutOfBoundsError
from .feature import EdgeFeature, Feature, NodeFeature
if TYPE_CHECKING:
from collections.abc import Mapping
from numpy.typing import NDArray
from .types import FeatureContainerDict
[docs]
class FeatureContainer:
    """Container for holding node and edge features.

    Parameters
    ----------
    features: Mapping[str, Feature]
        A mapping of feature keys to Feature objects. Features can be either
        NodeFeature or EdgeFeature.

    Raises
    ------
    FeatureKeyError
        If ``features`` does not contain any NodeFeature.
    IndexMismatchError
        If the NodeFeatures do not all have the same length.
    IndexOutOfBoundsError
        If an EdgeFeature has a source or destination index that is negative or
        not smaller than the number of nodes.

    Notes
    -----
    features must contain at least one NodeFeature. All NodeFeatures must have the
    same length.
    """

    def __init__(self, features: Mapping[str, Feature]) -> None:
        # Shallow-copy the mapping so later changes to the caller's mapping do not
        # alter this container.
        self._features = dict(features)
        # Node lengths are validated first: the edge check relies on ``len(self)``.
        self._check_node_lengths()
        self._check_edge_indices()

    def __len__(self) -> int:
        """Return the number of nodes in the container."""
        node_lengths = {
            len(f.value) for f in self._features.values() if isinstance(f, NodeFeature)
        }
        # ``_check_node_lengths`` guarantees this set holds exactly one value.
        return node_lengths.pop()

    def __getitem__(self, key: str) -> Feature:
        """Get a feature by its key.

        Raises
        ------
        FeatureKeyError
            If ``key`` is not present in the container.
        """
        try:
            return self._features[key]
        except KeyError:
            # Suppress the internal KeyError so callers only see FeatureKeyError.
            raise FeatureKeyError(key) from None

    def __contains__(self, key: str) -> bool:
        """Check if a feature key exists in the container."""
        return key in self._features

    def __repr__(self) -> str:
        """Return a string representation of the container."""
        return f"FeatureContainer(keys={list(self._features.keys())})"

    def keys(self) -> list[str]:
        """List of all feature keys in the container."""
        return list(self._features.keys())

    def crop(self, indices: NDArray[np.integer]) -> FeatureContainer:
        """Crop all features to only include the specified indices.

        Parameters
        ----------
        indices: NDArray[np.integer]
            1D array of global node indices to keep. Only integer arrays are
            allowed.

        Returns
        -------
        FeatureContainer
            New container built from the result of calling ``crop(indices)`` on
            every feature.
        """
        return FeatureContainer(
            {key: feat.crop(indices) for key, feat in self._features.items()},
        )

    def to_dict(self) -> FeatureContainerDict:
        """Convert the container to a dictionary.

        Returns
        -------
        FeatureContainerDict
            Dictionary with a ``"nodes"`` and an ``"edges"`` entry, each mapping
            feature keys to the ``dataclasses.asdict`` form of the feature.
        """
        nodes = {
            key: asdict(values)
            for key, values in self._features.items()
            if isinstance(values, NodeFeature)
        }
        edges = {
            key: asdict(values)
            for key, values in self._features.items()
            if isinstance(values, EdgeFeature)
        }
        return {"nodes": nodes, "edges": edges}  # pyright: ignore[reportReturnType]

    @classmethod
    def from_dict(cls, data: FeatureContainerDict) -> FeatureContainer:
        """Create a FeatureContainer from a dictionary.

        Parameters
        ----------
        data : FeatureContainerDict
            Dictionary containing node and edge features under the ``"nodes"``
            and ``"edges"`` keys. A missing entry is treated as empty.

        Raises
        ------
        FeatureKeyError
            If the same key appears in both the node and the edge features.
        """
        nodes = {
            key: NodeFeature(**values) for key, values in data.get("nodes", {}).items()
        }
        edges = {
            key: EdgeFeature(**values) for key, values in data.get("edges", {}).items()
        }
        overlap_keys = nodes.keys() & edges.keys()
        if overlap_keys:
            msg = f"Feature keys cannot be both node and edge features: {overlap_keys}"
            raise FeatureKeyError(msg)
        # Build through ``cls`` so subclasses get an instance of their own type.
        return cls(features={**nodes, **edges})

    def update(self, **features: Feature | NDArray[Any]) -> FeatureContainer:
        """Update the container with new or modified features.

        Parameters
        ----------
        **features: Feature | NDArray[Any]
            Key-value pairs of features to add or update. Values can be either
            Feature objects or numpy arrays (which will be converted to
            NodeFeature).

        Notes
        -----
        This method returns a new FeatureContainer instance and does not modify the
        current instance.

        Examples
        --------
        .. code-block:: python

            container = FeatureContainer(...)
            new_container = container.update(coord=container["coord"] + 1.0)
        """
        _features = dict(self._features)
        _features.update(
            {
                key: value if isinstance(value, Feature) else NodeFeature(value)
                for key, value in features.items()
            },
        )
        return FeatureContainer(_features)

    def remove(self, *keys: str) -> FeatureContainer:
        """Remove features by their keys.

        Parameters
        ----------
        *keys: str
            Keys of the features to remove.

        Raises
        ------
        FeatureKeyError
            If any of ``keys`` is not present in the container.

        Notes
        -----
        This method returns a new FeatureContainer instance and does not modify the
        current instance.

        Examples
        --------
        .. code-block:: python

            container = FeatureContainer(...)
            new_container = container.remove("coord", "element")
        """
        _features = dict(self._features)
        for key in keys:
            if key not in _features:
                raise FeatureKeyError(key)
            del _features[key]
        return FeatureContainer(_features)

    @classmethod
    def concat(cls, containers: list[FeatureContainer]) -> FeatureContainer:
        """Concatenate multiple FeatureContainer instances.

        Parameters
        ----------
        containers: list[FeatureContainer]
            List of FeatureContainer instances to concatenate.

        Returns
        -------
        FeatureContainer
            Concatenated FeatureContainer.

        Raises
        ------
        ValueError
            If ``containers`` is empty.
        FeatureKeyError
            If the containers do not share the same set of feature keys, or if a
            key is a node feature in one container and an edge feature in another.

        Notes
        -----
        All containers must have the same set of feature keys. Always returns a new
        FeatureContainer instance, even if only one container is provided.

        Examples
        --------
        .. code-block:: python

            container1 = FeatureContainer(...)
            container2 = FeatureContainer(...)
            concatenated = FeatureContainer.concat([container1, container2])
        """
        if not containers:
            msg = "No FeatureContainer instances provided for concatenation."
            raise ValueError(msg)

        base_keys = containers[0].keys()
        base_key_set = set(base_keys)
        for container in containers[1:]:
            container_keys = set(container.keys())
            if container_keys != base_key_set:
                msg = (
                    "All containers must have the same feature keys. "
                    f"Missing keys: {base_key_set - container_keys}. "
                    f"Extra keys: {container_keys - base_key_set}."
                )
                raise FeatureKeyError(msg)

        # Edge indices of the i-th container are shifted by the number of nodes in
        # all preceding containers. The shifts are the same for every edge feature,
        # so they are computed once up front.
        counts = [len(c) for c in containers]
        offsets = np.cumsum([0, *counts[:-1]])

        new_features: dict[str, Feature] = {}
        for key in base_keys:
            features = [container[key] for container in containers]
            if all(isinstance(feat, NodeFeature) for feat in features):
                new_features[key] = NodeFeature(
                    np.concatenate([feature.value for feature in features], axis=0),
                )
            elif all(isinstance(feat, EdgeFeature) for feat in features):
                features = cast("list[EdgeFeature]", features)
                all_src = [
                    feature.src_indices + offset
                    for feature, offset in zip(features, offsets, strict=True)
                ]
                all_dst = [
                    feature.dst_indices + offset
                    for feature, offset in zip(features, offsets, strict=True)
                ]
                new_features[key] = EdgeFeature(
                    np.concatenate([feature.value for feature in features], axis=0),
                    src_indices=np.concatenate(all_src, axis=0),
                    dst_indices=np.concatenate(all_dst, axis=0),
                )
            else:
                msg = (
                    f"Feature '{key}' has mixed types across containers: "
                    f"{ {type(f) for f in features} }"
                )
                raise FeatureKeyError(msg)
        # Build through ``cls`` so subclasses get an instance of their own type.
        return cls(new_features)

    def copy(self) -> FeatureContainer:
        """Create a new FeatureContainer holding a ``copy()`` of every feature."""
        return FeatureContainer(
            {key: feat.copy() for key, feat in self._features.items()},
        )

    def _check_node_lengths(self) -> None:
        """Validate that there is a node feature and all node lengths agree.

        Raises
        ------
        FeatureKeyError
            If the container holds no NodeFeature.
        IndexMismatchError
            If the NodeFeatures have more than one distinct length.
        """
        node_lengths = {
            len(f.value) for f in self._features.values() if isinstance(f, NodeFeature)
        }
        if not node_lengths:
            msg = "FeatureContainer must contain at least one node feature."
            raise FeatureKeyError(msg)
        if len(node_lengths) > 1:
            msg = f"Inconsistent node feature lengths {node_lengths}"
            raise IndexMismatchError(msg)

    def _check_edge_indices(self) -> None:
        """Validate that every edge index lies in ``[0, len(self))``.

        Raises
        ------
        IndexOutOfBoundsError
            If any source or destination index of an EdgeFeature is negative or
            greater than or equal to the number of nodes.
        """
        length = len(self)
        for key, feat in self._features.items():
            if not isinstance(feat, EdgeFeature):
                continue
            src, dst = feat.src_indices, feat.dst_indices
            # Negative indices must be rejected too: ``concat`` adds offsets to the
            # indices, which would silently turn a negative index into a reference
            # to a node of a preceding container.
            if np.any(src < 0) or np.any(dst < 0):
                msg = (
                    f"Pair feature '{key}' has negative indices. "
                    f"Min index allowed is 0, "
                    f"but got src_indices min={src.min()} and "
                    f"dst_indices min={dst.min()}."
                )
                raise IndexOutOfBoundsError(msg)
            if np.any(src >= length) or np.any(dst >= length):
                msg = (
                    f"Pair feature '{key}' has out-of-bounds indices. "
                    f"Max index allowed is {length - 1}, "
                    f"but got src_indices max={src.max()} and "
                    f"dst_indices max={dst.max()}."
                )
                raise IndexOutOfBoundsError(msg)