Source code for deepqmc.gnn.edge_features
import jax.numpy as jnp
from ..utils import norm
[docs]class EdgeFeature:
r"""Base class for all edge features."""
def __len__(self):
"""Return the length of the output feature vector."""
raise NotImplementedError
[docs]class DifferenceEdgeFeature(EdgeFeature):
"""Return the difference vector as the edge features."""
def __init__(self, *, log_rescale=False):
self.log_rescale = log_rescale
def __call__(self, d):
if self.log_rescale:
r = norm(d, safe=True)
d *= (jnp.log1p(r) / r)[..., None]
return d
def __len__(self):
return 3
[docs]class DistancePowerEdgeFeature(EdgeFeature):
"""Return powers of the distance as edge features."""
def __init__(self, *, powers, eps=None, log_rescale=False):
if any(p < 0 for p in powers):
assert eps is not None
self.powers = jnp.asarray(powers)
self.eps = eps or 0.0
self.log_rescale = log_rescale
def __call__(self, d):
r = norm(d, safe=True)
powers = jnp.where(
self.powers > 0,
r[..., None] ** self.powers,
1 / (r[..., None] ** (-self.powers) + self.eps),
)
if self.log_rescale:
powers *= (jnp.log1p(r) / r)[..., None]
return powers
def __len__(self):
return len(self.powers)
[docs]class GaussianEdgeFeature(EdgeFeature):
r"""
Expand the distance in a Gaussian basis.
Args:
n_gaussian (int): the number of gaussians to use,
consequently the length of the feature vector
radius (float): the radius within which to place gaussians
offset (bool): whether to offset the position of the first
Gaussian from zero.
"""
def __init__(self, *, n_gaussian, radius, offset):
delta = 1 / (2 * n_gaussian) if offset else 0
qs = jnp.linspace(delta, 1 - delta, n_gaussian)
self.mus = radius * qs**2
self.sigmas = (1 + radius * qs) / 7
def __call__(self, d):
r = norm(d, safe=True)
gaussians = jnp.exp(-((r[..., None] - self.mus) ** 2) / self.sigmas**2)
return gaussians
def __len__(self):
return len(self.mus)
[docs]class CombinedEdgeFeature(EdgeFeature):
r"""Combine multiple edge features.
Args:
features (Sequence): a :data:`Sequence` of edge feature objects
to combine.
"""
def __init__(self, *, features):
self.features = features
def __call__(self, d):
return jnp.concatenate([f(d) for f in self.features], axis=-1)
def __len__(self):
return sum(map(len, self.features))