from collections import namedtuple
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
# Node container with one field per particle kind: ``nuclei`` and ``electrons``.
GraphNodes = namedtuple('GraphNodes', 'nuclei electrons')
# A graph is a pair ``(nodes, edges)``; ``GraphUpdate`` unpacks and rebuilds it.
Graph = namedtuple('Graph', 'nodes edges')
# Names exported by a star-import of this module.
__all__ = [
    'GraphEdgeBuilder',
    'MolecularGraphEdgeBuilder',
    'GraphUpdate',
]
def offdiagonal_sender_idx(n_node):
    r"""
    Return, for every receiver node, the indices of all the other nodes.

    Args:
        n_node (int): number of nodes.

    Returns:
        An integer array of shape ``(n_node - 1, n_node)``. Column ``j`` lists
        every node index except ``j`` itself, in increasing order.
    """
    receivers = jnp.arange(n_node)[None, :]
    candidates = jnp.arange(n_node - 1)[:, None]
    # A candidate index ``i`` is bumped to ``i + 1`` as soon as it reaches the
    # receiver's own index, which is what skips the diagonal entry.
    reached_receiver = receivers <= candidates
    return reached_receiver + candidates
def compute_edges(pos_sender, pos_receiver, filter_diagonal):
    r"""
    Compute the difference vectors between sender and receiver positions.

    Args:
        pos_sender (float, (..., :math:`N`, :math:`D`)): coordinates of the nodes
            that send edges.
        pos_receiver (float, (..., :math:`M`, :math:`D`)): coordinates of the nodes
            that receive edges.
        filter_diagonal (bool): whether to drop the edges connecting a node with
            itself. Requires :math:`N = M`.

    Returns:
        An array of shape ``(..., N, M, D)`` whose entry ``[..., i, j, :]`` is
        ``pos_receiver[..., j, :] - pos_sender[..., i, :]``. If
        ``filter_diagonal`` is set the shape is ``(..., N - 1, N, D)`` and column
        ``j`` only contains the senders different from ``j``. For :math:`N = 0`
        the empty ``(..., 0, 0, D)`` array is returned.
    """
    diffs = pos_receiver[..., None, :, :] - pos_sender[..., None, :]
    if not filter_diagonal:
        return diffs
    assert pos_sender.shape[-2] == pos_receiver.shape[-2]
    n_node = pos_sender.shape[-2]
    if n_node == 0:
        # There is no diagonal to remove, and the index arrays below would need
        # the invalid shape (-1, 0); the empty difference array is the answer.
        return diffs
    receiver_idx = jnp.broadcast_to(jnp.arange(n_node)[None], (n_node - 1, n_node))
    sender_idx = offdiagonal_sender_idx(n_node)
    # Gather, for every receiver, the differences to all *other* nodes.
    return diffs[..., sender_idx, receiver_idx, :]
# [docs]  (Sphinx "view source" link marker left over from HTML extraction)
def GraphEdgeBuilder(
    mask_self,
):
    r"""
    Create a function that builds graph edges.

    Args:
        mask_self (bool): whether to mask edges that begin and end in the same node.

    Returns:
        The ``build`` function described below.
    """

    def build(pos_sender, pos_receiver):
        r"""
        Build graph edges.

        Args:
            pos_sender (float, (:math:`N_{nodes}`, 3)): coordinates of graph nodes
                that send edges.
            pos_receiver (float, (:math:`M_{nodes}`, 3)): coordinates of graph nodes
                that receive edges.

        Returns:
            The array of difference vectors produced by ``compute_edges``, with
            the self edges removed if ``mask_self`` is set.
        """
        # Both position arrays must hold three-dimensional coordinates.
        for positions in (pos_sender, pos_receiver):
            assert positions.shape[-1] == 3
        # Only the sender array is required to be free of batch dimensions.
        assert len(pos_sender.shape) == 2
        if mask_self:
            # Self edges only exist when senders and receivers are matched up.
            assert pos_sender.shape[0] == pos_receiver.shape[0]
        return compute_edges(pos_sender, pos_receiver, mask_self)

    return build
# [docs]  (Sphinx "view source" link marker left over from HTML extraction)
def MolecularGraphEdgeBuilder(n_nuc, n_up, n_down, edge_types, *, self_interaction):
    r"""
    Create a function that builds many types of molecular edges.

    Args:
        n_nuc (int): number of nuclei (not read by the builder itself).
        n_up (int): number of spin-up electrons. The electron coordinates are
            split at this index into a spin-up and a spin-down part.
        n_down (int): number of spin-down electrons.
        edge_types (list[str]): list of edge type names to build. Possible names are:

            - ``'nn'``: nuclei->nuclei edges
            - ``'ne'``: nuclei->electrons edges
            - ``'en'``: electrons->nuclei edges
            - ``'same'``: edges between same-spin electrons
            - ``'anti'``: edges between opposite-spin electrons
            - ``'up'``: edges going from spin-up electrons to all electrons
            - ``'down'``: edges going from spin-down electrons to all electrons
        self_interaction (bool): whether edges between a particle and itself are
            considered

    Returns:
        The ``build`` function described below.

    Raises:
        KeyError: if ``edge_types`` contains a name that is not listed above.
    """
    # Elementary builders required to assemble each user-facing edge type.
    builders_of_edge_type = {
        'nn': ['nn'],
        'ne': ['ne'],
        'en': ['en'],
        'same': ['uu', 'dd'],
        'anti': ['ud', 'du'],
        'up': ['up'],
        'down': ['down'],
    }
    # Self edges are only masked where senders and receivers are the same set
    # of particles; all other builders keep every edge.
    drop_self_edges = not self_interaction
    mask_self_of_builder = {
        'nn': drop_self_edges,
        'ne': False,
        'en': False,
        'uu': drop_self_edges,
        'dd': drop_self_edges,
        'ud': False,
        'du': False,
        'up': False,
        'down': False,
    }

    builders = {}
    for edge_type in edge_types:
        for builder_type in builders_of_edge_type[edge_type]:
            builders[builder_type] = GraphEdgeBuilder(
                mask_self=mask_self_of_builder[builder_type],
            )

    # One rule per edge type; ``pc.R`` holds the nuclear and ``pc.r`` the
    # electronic coordinates, with the spin-up electrons coming first.
    def build_nn(pc):
        return SimpleGraphEdges(builders['nn'](pc.R, pc.R))

    def build_ne(pc):
        return SimpleGraphEdges(builders['ne'](pc.R, pc.r))

    def build_en(pc):
        return SimpleGraphEdges(builders['en'](pc.r, pc.R))

    def build_same(pc):
        r_up, r_down = pc.r[:n_up], pc.r[n_up:]
        return SameGraphEdges(
            builders['uu'](r_up, r_up),
            builders['dd'](r_down, r_down),
        )

    def build_anti(pc):
        return AntiGraphEdges(
            builders['du'](pc.r[n_up:], pc.r[:n_up]),
            builders['ud'](pc.r[:n_up], pc.r[n_up:]),
        )

    def build_up(pc):
        return UpGraphEdges(builders['up'](pc.r[:n_up], pc.r))

    def build_down(pc):
        return DownGraphEdges(builders['down'](pc.r[n_up:], pc.r))

    build_rules = {
        'nn': build_nn,
        'ne': build_ne,
        'en': build_en,
        'same': build_same,
        'anti': build_anti,
        'up': build_up,
        'down': build_down,
    }

    def build(phys_conf):
        r"""
        Build many types of molecular graph edges.

        Args:
            phys_conf (~deepqmc.types.PhysicalConfiguration): the physical
                configuration of the molecule.

        Returns:
            A dictionary mapping each requested edge type name to the edges
            object built for it.
        """
        assert phys_conf.r.shape[0] == n_up + n_down
        edges = {}
        for edge_type in edge_types:
            edges[edge_type] = build_rules[edge_type](phys_conf)
        return edges

    return build
# [docs]  (Sphinx "view source" link marker left over from HTML extraction)
def GraphUpdate(
    aggregate_edges_for_nodes_fn,
    update_nodes_fn=None,
    update_edges_fn=None,
):
    r"""
    Create a function that updates a graph.

    The update function is tailored to be used in GNNs.

    Args:
        aggregate_edges_for_nodes_fn (~collections.abc.Callable): function called
            as ``fn(nodes, edges)`` whose result is handed to ``update_nodes_fn``.
            It is only invoked when ``update_nodes_fn`` is given.
        update_nodes_fn (~collections.abc.Callable): optional, function that updates the
            nodes, called as ``fn(nodes, aggregated_edges)``.
        update_edges_fn (~collections.abc.Callable): optional, function that updates the
            edges, called as ``fn(edges)``.

    Returns:
        A function mapping a ``(nodes, edges)`` graph to the updated ``Graph``.
    """

    def update_graph(graph):
        nodes, edges = graph
        new_nodes = nodes
        if update_nodes_fn:
            # The node update sees the edges as they were before this step.
            new_nodes = update_nodes_fn(
                nodes, aggregate_edges_for_nodes_fn(nodes, edges)
            )
        new_edges = update_edges_fn(edges) if update_edges_fn else edges
        return Graph(new_nodes, new_edges)

    return update_graph
class GraphEdges:
    r"""
    Interface shared by all edge containers.

    Every member raises :class:`NotImplementedError` here; subclasses provide
    the actual implementations.
    """

    @property
    def single_array(self):
        r"""Return all the edges of the container as a single array."""
        raise NotImplementedError()

    def update_from_single_array(self, array):
        r"""Create a new container from an array as given by ``single_array``."""
        raise NotImplementedError()

    def sum_senders(self, normalize=False):
        r"""Reduce the edges over their senders, optionally normalizing."""
        raise NotImplementedError()

    def convolve(self, nodes, normalize=False):
        r"""Combine the edges with the sender nodes and reduce over senders."""
        raise NotImplementedError()
@jdc.pytree_dataclass
class SimpleGraphEdges(GraphEdges):
    r"""
    Edges stored in a single array.

    The senders are indexed by axis ``-3`` of the array, which is the axis
    reduced by :meth:`sum_senders`.
    """

    edges: jax.Array  # edge array; axis -3 runs over the senders

    @property
    def single_array(self):
        r"""Return the edge array itself."""
        return self.edges

    def update_from_single_array(self, array):
        r"""Return a new instance of the same class holding ``array``."""
        return self.__class__(array)

    def sum_senders(self, normalize=False):
        r"""
        Reduce the edges over the sender axis.

        Args:
            normalize (bool): whether to average instead of summing.

        Returns:
            The edge array with the sender axis (``-3``) reduced away. With no
            senders at all the result is zero, also when ``normalize`` is set.
        """
        # The mean over an empty sender axis would be NaN (0 / 0); the sum is
        # used instead so that an empty set of senders contributes zeros.
        has_senders = self.edges.shape[-3] > 0
        reduce_fn = jnp.mean if normalize and has_senders else jnp.sum
        return reduce_fn(self.edges, axis=-3)

    def convolve(self, nodes, normalize=False):
        r"""
        Multiply every edge with its sender node and reduce over the senders.

        Args:
            nodes: array indexed by sender along its first axis; it is
                broadcast over the receiver axis of the edges.
            normalize (bool): forwarded to :meth:`sum_senders`.
        """
        edge_node_product = self.edges * nodes[:, None]
        return self.__class__(edge_node_product).sum_senders(normalize)
@jdc.pytree_dataclass
class UpGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the leading entries of the node array."""

    def convolve(self, nodes, normalize=False):
        r"""
        Multiply every edge with its sender node and reduce over the senders.

        Args:
            nodes: array indexed by node along its first axis; only the first
                ``n_sender`` entries are used, where ``n_sender`` is the size
                of the sender axis (``-3``) of the edges.
            normalize (bool): forwarded to ``sum_senders``.
        """
        n_sender = self.edges.shape[-3]
        sender_nodes = nodes[:n_sender, None]
        weighted = type(self)(self.edges * sender_nodes)
        return weighted.sum_senders(normalize)
@jdc.pytree_dataclass
class DownGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the trailing entries of the node array."""

    def convolve(self, nodes, normalize=False):
        r"""
        Multiply every edge with its sender node and reduce over the senders.

        Args:
            nodes: array indexed by node along its first axis; only the last
                ``n_sender`` entries are used, where ``n_sender`` is the size
                of the sender axis (``-3``) of the edges.
            normalize (bool): forwarded to ``sum_senders``.
        """
        n_sender = self.edges.shape[-3]
        # Slice from an explicit start index: ``nodes[-n_sender:]`` would select
        # *all* nodes instead of none when there are no senders (``-0 == 0``).
        first_sender = nodes.shape[0] - n_sender
        down = self.edges * nodes[first_sender:, None]
        return self.__class__(down).sum_senders(normalize)
@jdc.pytree_dataclass
class SameGraphEdges(GraphEdges):
    r"""
    Edges between same-spin electrons, stored as two separate blocks.

    In both blocks axis ``-3`` runs over the senders, axis ``-2`` over the
    receivers and axis ``-1`` over the features.
    """

    uu: jax.Array  # spin-up -> spin-up edges
    dd: jax.Array  # spin-down -> spin-down edges

    @property
    def single_array(self):
        r"""Return both blocks flattened over (sender, receiver) and concatenated."""
        batch_dims = self.uu.shape[:-3]
        return jnp.concatenate(
            [
                self.uu.reshape(*batch_dims, -1, self.uu.shape[-1]),
                self.dd.reshape(*batch_dims, -1, self.dd.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        r"""
        Create a new instance from an array laid out like ``single_array``.

        The sender and receiver counts are taken from the current blocks, the
        feature dimension from ``array``.
        """
        n_up = self.uu.shape[-2]
        n_down = self.dd.shape[-2]
        n_sender_up = self.uu.shape[-3]
        n_sender_down = self.dd.shape[-3]
        # The first ``n_up * n_sender_up`` flattened edges belong to ``uu``.
        uu, dd = jnp.split(array, (n_up * n_sender_up,), axis=-2)
        uu = uu.reshape(*uu.shape[:-2], n_sender_up, n_up, uu.shape[-1])
        dd = dd.reshape(*dd.shape[:-2], n_sender_down, n_down, dd.shape[-1])
        return self.__class__(uu, dd)

    def sum_senders(self, normalize=False):
        r"""
        Sum each block over its senders and concatenate the receivers.

        Args:
            normalize (bool): whether to divide each block by its number of
                senders (a block without senders is divided by one).
        """
        norm_uu, norm_dd = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.uu, self.dd)
        )
        up, down = (
            jnp.sum(self.uu, axis=-3) / norm_uu,
            jnp.sum(self.dd, axis=-3) / norm_dd,
        )
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        r"""
        Multiply every edge with its sender node and reduce over the senders.

        Args:
            nodes: array indexed by electron along its first axis, with the
                spin-up electrons first.
            normalize (bool): forwarded to :meth:`sum_senders`.
        """
        n_up = self.uu.shape[-2]
        n_down = self.dd.shape[-2]
        # A square block contains the self edges, so sender ``i`` is simply node
        # ``i``. Otherwise the diagonal was removed and the senders of each
        # receiver have to be looked up. This is decided per block, because an
        # empty (0, 0) block says nothing about the layout of the other one.
        if self.uu.shape[-3] == n_up:
            up_nodes = nodes[:n_up, None]
        else:
            up_nodes = nodes[offdiagonal_sender_idx(n_up)]
        if self.dd.shape[-3] == n_down:
            down_nodes = nodes[n_up:, None]
        else:
            # Spin-down electrons follow the spin-up ones in ``nodes``.
            down_nodes = nodes[n_up + offdiagonal_sender_idx(n_down)]
        uu = self.uu * up_nodes
        dd = self.dd * down_nodes
        return self.__class__(uu, dd).sum_senders(normalize)
@jdc.pytree_dataclass
class AntiGraphEdges(GraphEdges):
    r"""
    Edges between opposite-spin electrons, stored as two separate blocks.

    In both blocks axis ``-3`` runs over the senders, axis ``-2`` over the
    receivers and axis ``-1`` over the features.
    """

    du: jax.Array  # spin-down -> spin-up edges
    ud: jax.Array  # spin-up -> spin-down edges

    @property
    def single_array(self):
        r"""Return both blocks flattened over (sender, receiver) and concatenated."""
        batch_dims = self.du.shape[:-3]
        return jnp.concatenate(
            [
                self.du.reshape(*batch_dims, -1, self.du.shape[-1]),
                self.ud.reshape(*batch_dims, -1, self.ud.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        r"""
        Create a new instance from an array laid out like ``single_array``.

        The electron counts are taken from the current blocks, the feature
        dimension from ``array``.
        """
        n_up = self.du.shape[-2]
        n_down = self.ud.shape[-2]
        # Split along the flattened edge axis (-2), the axis ``single_array``
        # concatenates on, so that leading batch dimensions are left intact.
        du, ud = jnp.split(array, (n_up * n_down,), axis=-2)
        du = du.reshape(*du.shape[:-2], n_down, n_up, du.shape[-1])
        ud = ud.reshape(*ud.shape[:-2], n_up, n_down, ud.shape[-1])
        return self.__class__(du, ud)

    def sum_senders(self, normalize=False):
        r"""
        Sum each block over its senders and concatenate the receivers.

        Args:
            normalize (bool): whether to divide each block by its number of
                senders (a block without senders is divided by one).

        Returns:
            The reduced ``du`` block (spin-up receivers) followed by the
            reduced ``ud`` block (spin-down receivers) along axis ``-2``.
        """
        norm_du, norm_ud = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.du, self.ud)
        )
        up, down = (
            jnp.sum(self.du, axis=-3) / norm_du,
            jnp.sum(self.ud, axis=-3) / norm_ud,
        )
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        r"""
        Multiply every edge with its sender node and reduce over the senders.

        Args:
            nodes: array indexed by electron along its first axis, with the
                spin-up electrons first.
            normalize (bool): forwarded to :meth:`sum_senders`.
        """
        n_up = self.du.shape[-2]
        # ``du`` edges are sent by the spin-down electrons (the trailing nodes),
        # ``ud`` edges by the spin-up electrons (the leading nodes).
        du = self.du * nodes[n_up:, None]
        ud = self.ud * nodes[:n_up, None]
        return self.__class__(du, ud).sum_senders(normalize)