Source code for deepqmc.gnn.edge_features
from typing import Optional, Protocol
import jax
import jax.numpy as jnp
from ..utils import norm
[docs]
class EdgeFeature(Protocol):
    r"""Structural interface shared by every edge-feature transform.

    An edge feature maps the difference vector of an edge to a feature vector
    of fixed length, and reports that length through ``len()``.
    """

    def __len__(self) -> int:
        """Give the number of entries in the produced feature vector."""
        ...

    def __call__(self, d: jax.Array) -> jax.Array:
        r"""Compute the edge features for the difference vector ``d``."""
        ...
[docs]
class DifferenceEdgeFeature(EdgeFeature):
    r"""Return the difference vector as the edge features.

    Args:
        log_rescale (bool): whether to rescale the features by
            :math:`\log(1 + r) / r` where :math:`r` is the length of the edge.
    """

    def __init__(self, *, log_rescale: bool = False):
        self.log_rescale = log_rescale

    def __call__(self, d: jax.Array) -> jax.Array:
        if not self.log_rescale:
            # Without rescaling the difference vector itself is the feature.
            return d
        r = norm(d, safe=True)
        # Build a new array instead of using ``*=``: if ``d`` is a mutable
        # (e.g. NumPy) array, an in-place update could modify the caller's input.
        return d * (jnp.log1p(r) / r)[..., None]

    def __len__(self) -> int:
        # The features are ``d`` itself, whose last dimension is assumed to be 3.
        return 3
[docs]
class DistancePowerEdgeFeature(EdgeFeature):
    r"""Return powers of the distance as edge features.

    For a non-negative power :math:`p` the feature is :math:`r^p`; for a
    negative power it is :math:`1 / (r^{-p} + \epsilon)`.

    Args:
        powers (list[float]): a :data:`list` of powers to apply to the edge length.
        eps (float | None): a small value to add to the denominator when the power
            is negative. Required if any of the powers is negative.
        log_rescale (bool): whether to rescale the features by
            :math:`\log(1 + r) / r` where :math:`r` is the length of the edge.

    Raises:
        ValueError: if a negative power is requested but ``eps`` is not given.
    """

    def __init__(
        self,
        *,
        powers: list[float],
        eps: Optional[float] = None,
        log_rescale: bool = False,
    ):
        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if eps is None and any(p < 0 for p in powers):
            raise ValueError("eps must be given when any of the powers is negative")
        self.powers = jnp.asarray(powers)
        self.eps = eps or 0.0
        self.log_rescale = log_rescale

    def __call__(self, d: jax.Array) -> jax.Array:
        r = norm(d, safe=True)
        # r ** |p| equals r ** p for p >= 0 and the denominator r ** (-p) for
        # p < 0, so one power evaluation serves both branches of the where.
        r_abs_pow = r[..., None] ** jnp.abs(self.powers)
        features = jnp.where(
            # p == 0 belongs to the plain-power branch (constant feature 1),
            # consistent with ``eps`` being required only for p < 0.
            self.powers >= 0,
            r_abs_pow,
            1 / (r_abs_pow + self.eps),
        )
        if self.log_rescale:
            features = features * (jnp.log1p(r) / r)[..., None]
        return features

    def __len__(self) -> int:
        return len(self.powers)
[docs]
class GaussianEdgeFeature(EdgeFeature):
    r"""Expand the edge length in a basis of Gaussian functions.

    With :math:`q_k` evenly spaced on :math:`[\delta, 1 - \delta]`, where
    :math:`\delta = 1 / (2 n)` if ``offset`` is set and :math:`0` otherwise,
    the centers are :math:`\mu_k = R q_k^2` and the widths are
    :math:`\sigma_k = (1 + R q_k) / 7`. Feature :math:`k` of an edge of
    length :math:`r` is :math:`\exp(-(r - \mu_k)^2 / \sigma_k^2)`.

    Args:
        n_gaussian (int): the number of gaussians to use,
            consequently the length of the feature vector
        radius (float): the radius within which to place gaussians
        offset (bool): whether to offset the position of the first
            Gaussian from zero.
    """

    def __init__(self, *, n_gaussian: int, radius: float, offset: bool):
        if offset:
            margin = 1 / (2 * n_gaussian)
        else:
            margin = 0
        grid = jnp.linspace(margin, 1 - margin, n_gaussian)
        # Quadratic spacing packs more centers close to zero distance.
        self.mus = radius * grid**2
        self.sigmas = (1 + radius * grid) / 7

    def __call__(self, d: jax.Array) -> jax.Array:
        dist = norm(d, safe=True)[..., None]
        return jnp.exp(-((dist - self.mus) ** 2) / self.sigmas**2)

    def __len__(self) -> int:
        return self.mus.shape[0]
[docs]
class CombinedEdgeFeature(EdgeFeature):
    r"""Combine multiple edge features.

    The outputs of the individual features are concatenated along the last
    axis, in the order in which the features are given.

    Args:
        features (list[~deepqmc.gnn.edge_features.EdgeFeature]): a list of edge feature
            objects to combine.
    """

    def __init__(self, *, features: list[EdgeFeature]):
        self.features = features

    def __call__(self, d: jax.Array) -> jax.Array:
        if not self.features:
            # jnp.concatenate rejects an empty sequence; return an empty
            # feature axis instead, consistent with ``len(self) == 0``.
            return jnp.zeros((*d.shape[:-1], 0), dtype=d.dtype)
        return jnp.concatenate([f(d) for f in self.features], axis=-1)

    def __len__(self) -> int:
        return sum(len(f) for f in self.features)