from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Final
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin
from typing_extensions import Self, override
from biomol.exceptions import (
FeatureOperationError,
IndexInvalidError,
IndexMismatchError,
)
if TYPE_CHECKING:
from numpy.typing import DTypeLike, NDArray
_LOGICAL_UFUNCS: Final[set[np.ufunc]] = {
np.logical_and,
np.logical_or,
np.logical_xor,
np.logical_not,
}
_COMPARISON_UFUNCS: Final[set[np.ufunc]] = {
np.equal,
np.not_equal,
np.less,
np.less_equal,
np.greater,
np.greater_equal,
}
[docs]
@dataclass(frozen=True, slots=True, eq=False)
class Feature(NDArrayOperatorsMixin, ABC):
"""A base class for features in a structure.
This class supports numpy operations and can be indexed and cropped.
"""
value: NDArray[Any]
"""The underlying numpy array representing the feature data."""
__array_priority__ = 1000
@property
def shape(self) -> tuple[int, ...]:
"""Return the shape of the feature."""
return self.value.shape
@property
def ndim(self) -> int:
"""Return the number of dimensions of the feature."""
return self.value.ndim
@property
def dtype(self) -> DTypeLike:
"""Return the data type of the feature."""
return self.value.dtype
@property
def size(self) -> int:
"""Return the total number of elements in the feature."""
return self.value.size
[docs]
def mean(self, axis: int | None = None, **kwargs: Any) -> Any: # noqa: ANN401
"""Return the mean of the feature along the specified axis."""
return self.value.mean(axis=axis, **kwargs)
[docs]
def sum(self, axis: int | None = None, **kwargs: Any) -> Any: # noqa: ANN401
"""Return the sum of the feature along the specified axis."""
return self.value.sum(axis=axis, **kwargs)
[docs]
def min(self, axis: int | None = None, **kwargs: Any) -> Any: # noqa: ANN401
"""Return the minimum of the feature along the specified axis."""
return self.value.min(axis=axis, **kwargs)
[docs]
def max(self, axis: int | None = None, **kwargs: Any) -> Any: # noqa: ANN401
"""Return the maximum of the feature along the specified axis."""
return self.value.max(axis=axis, **kwargs)
[docs]
@abstractmethod
def crop(self, indices: NDArray[np.integer]) -> Self:
"""Crop the feature to only include the specified indices.
Parameters
----------
indices: NDArray[np.integer]
1D array of node indices to keep. Only integer arrays is allowed.
"""
[docs]
@abstractmethod
def copy(self) -> Self:
"""Return a deep copy of the feature.
Returns
-------
Self
A new instance with copied numpy arrays.
"""
@abstractmethod
def __getitem__(self, key: Any) -> Self: # noqa: ANN401
"""Get a subset of the feature."""
def __len__(self) -> int:
"""Return the number of entries in the feature."""
return len(self.value)
def __bool__(self) -> bool:
"""Prevent ambiguous truth value evaluation."""
return bool(self.value)
def __array__(self, dtype: DTypeLike | None = None) -> NDArray[Any]:
"""Convert the feature to a numpy array.
This method is called when numpy functions are applied to the feature.
"""
return np.asarray(self.value, dtype=dtype)
def __array_ufunc__(
self,
ufunc: np.ufunc,
method: str,
*inputs: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> Any: # noqa: ANN401
"""Support numpy universal functions (ufuncs).
This method is called when numpy ufuncs are applied to the feature.
"""
if method == "at":
msg = (
f"{type(self).__name__} is immutable; "
"in-place operations are not supported."
)
raise FeatureOperationError(msg)
if "out" in kwargs and kwargs["out"] is not None:
outs = kwargs["out"]
if not isinstance(outs, tuple):
outs = (outs,)
if any(isinstance(o, Feature) for o in outs):
msg = (
f"{type(self).__name__} is immutable; "
"in-place operations are not supported."
)
raise FeatureOperationError(msg)
args = [x.value if isinstance(x, Feature) else x for x in inputs]
res = getattr(ufunc, method)(*args, **kwargs)
if method in ("reduce", "reduceat", "accumulate", "outer", "inner"):
return res
if ufunc in _COMPARISON_UFUNCS or ufunc in _LOGICAL_UFUNCS:
return res
if isinstance(res, tuple):
return tuple(
replace(self, value=r)
if isinstance(r, np.ndarray) and r.shape == self.shape
else r
for r in res
)
if isinstance(res, np.ndarray) and res.shape == self.shape:
return replace(self, value=res)
return res
[docs]
@dataclass(frozen=True, slots=True, eq=False)
class NodeFeature(Feature):
"""A feature associated with nodes in a structure.
Parameters
----------
value: np.ndarray
A numpy array where the first dimension corresponds to the nodes.
"""
[docs]
@override
def crop(self, indices: NDArray[np.integer]) -> Self:
return self[indices]
[docs]
@override
def copy(self) -> Self:
return replace(self, value=self.value.copy())
@override
def __getitem__(self, key: Any) -> Self:
return replace(self, value=self.value[key])
[docs]
@dataclass(frozen=True, slots=True, eq=False)
class EdgeFeature(Feature):
"""A feature associated with edges (pairs of nodes) in a structure.
Parameters
----------
value: np.ndarray
A numpy array where the first dimension corresponds to the edges.
src_indices: NDArray[np.integer]
A 1D numpy array of source node indices for each edge.
dst_indices: NDArray[np.integer]
A 1D numpy array of destination node indices for each edge.
"""
src_indices: NDArray[np.integer]
"""Source node indices of the edges."""
dst_indices: NDArray[np.integer]
"""Destination node indices of the edges."""
def __post_init__(self) -> None: # noqa: D105
if not (self.src_indices.ndim == 1 and self.dst_indices.ndim == 1):
msg = (
"src_indices and dst_indices must be 1D arrays. Got "
f"src_indices={self.src_indices.ndim}, "
f"dst_indices={self.dst_indices.ndim}"
)
raise IndexInvalidError(msg)
if not (len(self.value) == len(self.src_indices) == len(self.dst_indices)):
msg = (
"All arrays must have the same length. Got "
f"value={len(self.value)}, src_indices={len(self.src_indices)}, "
f"dst_indices={len(self.dst_indices)}"
)
raise IndexMismatchError(msg)
if np.any(self.src_indices < 0) or np.any(self.dst_indices < 0):
msg = "src_indices and dst_indices must be non-negative."
raise IndexInvalidError(msg)
@property
def src(self) -> NDArray[np.integer]:
"""Return the source node indices of the edges."""
return self.src_indices
@property
def dst(self) -> NDArray[np.integer]:
"""Return the destination node indices of the edges."""
return self.dst_indices
@property
def nodes(self) -> NDArray[np.integer]:
"""Return the unique node indices involved in the edges."""
return np.unique(np.concatenate([self.src_indices, self.dst_indices]))
[docs]
@override
def crop(self, indices: NDArray[np.integer]) -> Self:
"""Crop the feature to only include the specified indices.
Keep only pairs (i, j) whose both endpoints are in `indices`.
Parameters
----------
indices: NDArray[np.integer]
1D array of node indices to keep. Only integer arrays is allowed.
"""
if not isinstance(indices, np.ndarray):
msg = f"Indices must be a numpy.ndarray, got {type(indices)}"
raise IndexInvalidError(msg)
if indices.ndim != 1:
msg = f"Indices must be a 1D array, got {indices.ndim}D array"
raise IndexInvalidError(msg)
if not np.issubdtype(indices.dtype, np.integer):
msg = f"Indices must be a integer array, got {indices.dtype}"
raise IndexInvalidError(msg)
if np.any(indices < 0):
msg = "Negative indices are not allowed."
raise IndexInvalidError(msg)
if self.value.size == 0 or indices.size == 0:
return self._empty_like()
kept, idx = np.unique(indices, return_index=True)
src_in_kept = np.isin(self.src_indices, kept)
dst_in_kept = np.isin(self.dst_indices, kept)
row_mask = src_in_kept & dst_in_kept
if not row_mask.any():
return self._empty_like()
new_src = idx[np.searchsorted(kept, self.src_indices[row_mask])]
new_dst = idx[np.searchsorted(kept, self.dst_indices[row_mask])]
return replace(
self,
value=self.value[row_mask],
src_indices=new_src,
dst_indices=new_dst,
)
[docs]
@override
def copy(self) -> Self:
"""Return a deep copy of the edge feature.
Returns
-------
Self
A new instance with all numpy arrays copied.
"""
return replace(
self,
value=self.value.copy(),
src_indices=self.src_indices.copy(),
dst_indices=self.dst_indices.copy(),
)
def _empty_like(self) -> Self:
empty = np.empty((0, *self.value.shape[1:]), dtype=self.value.dtype)
ind = np.empty((0,), dtype=self.src_indices.dtype)
return replace(self, value=empty, src_indices=ind, dst_indices=ind)
@override
def __getitem__(self, key: Any) -> Self:
return replace(
self,
value=self.value[key],
src_indices=self.src_indices[key],
dst_indices=self.dst_indices[key],
)