Source code for deepqmc.gnn.graph

from collections import namedtuple

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

# Plain named-tuple containers: node data split into nuclei and electrons, and
# the (nodes, edges) pair that GraphUpdate unpacks and returns.
GraphNodes = namedtuple('GraphNodes', ['nuclei', 'electrons'])
Graph = namedtuple('Graph', ['nodes', 'edges'])

# Names re-exported by ``from deepqmc.gnn.graph import *``; the edge container
# classes and the graph named tuples stay reachable as module attributes.
__all__ = ['GraphEdgeBuilder', 'MolecularGraphEdgeBuilder', 'GraphUpdate']


def offdiagonal_sender_idx(n_node):
    r"""Return the sender indices of all off-diagonal edges, per receiver.

    Column ``j`` of the result lists, in increasing order, every node index in
    ``range(n_node)`` except ``j`` itself.

    Args:
        n_node (int): number of nodes.

    Returns:
        integer array of shape ``(n_node - 1, n_node)`` for ``n_node >= 1``.
    """
    rank = jnp.arange(n_node - 1)[:, None]
    receiver = jnp.arange(n_node)[None, :]
    # once the running rank reaches the receiver, shift it by one so that the
    # receiver's own index (the diagonal) is skipped
    return rank + (receiver <= rank)


def compute_edges(pos_sender, pos_receiver, filter_diagonal):
    r"""Compute the difference vectors of all sender-receiver node pairs.

    Args:
        pos_sender (float, (..., N, D)): positions of the sending nodes.
        pos_receiver (float, (..., M, D)): positions of the receiving nodes.
        filter_diagonal (bool): whether to drop the edges that connect a node
            with itself; requires ``N == M``.

    Returns:
        float, (..., N, M, D): entry ``[..., i, j, :]`` is
        ``pos_receiver[..., j, :] - pos_sender[..., i, :]``. If
        ``filter_diagonal`` is true, the sender axis instead holds, for each
        receiver ``j``, the ``N - 1`` senders other than ``j`` in increasing
        order, giving shape (..., N - 1, N, D), or (..., 0, 0, D) if ``N == 0``.
    """
    diffs = pos_receiver[..., None, :, :] - pos_sender[..., None, :]
    if filter_diagonal:
        assert pos_sender.shape[-2] == pos_receiver.shape[-2]
        n_node = pos_sender.shape[-2]
        # Clamp so that an empty node set yields (0, 0) index arrays instead of
        # requesting a negative dimension from broadcast_to; this matches the
        # (0, 0) shape that offdiagonal_sender_idx(0) returns.
        n_offdiagonal = max(n_node - 1, 0)
        receiver_idx = jnp.broadcast_to(
            jnp.arange(n_node)[None], (n_offdiagonal, n_node)
        )
        sender_idx = offdiagonal_sender_idx(n_node)
        diffs = diffs[..., sender_idx, receiver_idx, :]
    return diffs


def GraphEdgeBuilder(mask_self):
    r"""
    Create a function that builds graph edges.

    Args:
        mask_self (bool): whether to mask edges that begin and end in the same
            node.
    """

    def build(pos_sender, pos_receiver):
        r"""
        Build graph edges.

        Args:
            pos_sender (float, (:math:`N_{nodes}`, 3)): coordinates of graph
                nodes that send edges.
            pos_receiver (float, (:math:`M_{nodes}`, 3)): coordinates of graph
                nodes that receive edges.

        Returns:
            jax.Array: the difference vectors, with
            ``pos_receiver[j] - pos_sender[i]`` stored at index ``[i, j]``,
            of shape (:math:`N_{nodes}`, :math:`M_{nodes}`, 3). If
            ``mask_self`` is true, the sender axis only holds the
            :math:`N_{nodes} - 1` senders distinct from each receiver, giving
            shape (:math:`N_{nodes} - 1`, :math:`N_{nodes}`, 3).
        """
        assert pos_sender.shape[-1] == 3 and pos_receiver.shape[-1] == 3
        assert len(pos_sender.shape) == 2
        # masking self-edges only makes sense for a square sender/receiver set
        assert not mask_self or pos_sender.shape[0] == pos_receiver.shape[0]
        return compute_edges(pos_sender, pos_receiver, mask_self)

    return build
def MolecularGraphEdgeBuilder(n_nuc, n_up, n_down, edge_types, *, self_interaction):
    r"""
    Create a function that builds many types of molecular edges.

    Args:
        n_nuc (int): number of nuclei. Not referenced by the edge construction
            in this function.
        n_up (int): number of spin-up electrons.
        n_down (int): number of spin-down electrons.
        edge_types (list[str]): list of edge type names to build. Possible names
            are:

            - ``'nn'``: nuclei->nuclei edges
            - ``'ne'``: nuclei->electrons edges
            - ``'en'``: electrons->nuclei edges
            - ``'same'``: edges between same-spin electrons
            - ``'anti'``: edges between opposite-spin electrons
            - ``'up'``: edges going from spin-up electrons to all electrons
            - ``'down'``: edges going from spin-down electrons to all electrons
        self_interaction (bool): whether edges between a particle and itself
            are considered. Only applied to ``'nn'`` and ``'same'`` edges;
            ``'up'`` and ``'down'`` edges always include each sender among the
            receivers.

    Raises:
        KeyError: if ``edge_types`` contains an unknown name.
    """
    # each user-facing edge type is assembled from one or two elementary builders
    builder_mapping = {
        'nn': ['nn'],
        'ne': ['ne'],
        'en': ['en'],
        'same': ['uu', 'dd'],
        'anti': ['ud', 'du'],
        'up': ['up'],
        'down': ['down'],
    }
    # elementary builders whose sender and receiver sets coincide, so that
    # self-edges exist and can be masked
    maskable_builder_types = {'nn', 'uu', 'dd'}
    builders = {
        builder_type: GraphEdgeBuilder(
            mask_self=(
                builder_type in maskable_builder_types and not self_interaction
            ),
        )
        for edge_type in edge_types
        for builder_type in builder_mapping[edge_type]
    }
    # the lambdas look up `builders` lazily, so only requested types are built
    build_rules = {
        'nn': lambda pc: SimpleGraphEdges(builders['nn'](pc.R, pc.R)),
        'ne': lambda pc: SimpleGraphEdges(builders['ne'](pc.R, pc.r)),
        'en': lambda pc: SimpleGraphEdges(builders['en'](pc.r, pc.R)),
        'same': lambda pc: SameGraphEdges(
            builders['uu'](pc.r[:n_up], pc.r[:n_up]),
            builders['dd'](pc.r[n_up:], pc.r[n_up:]),
        ),
        'anti': lambda pc: AntiGraphEdges(
            builders['du'](pc.r[n_up:], pc.r[:n_up]),
            builders['ud'](pc.r[:n_up], pc.r[n_up:]),
        ),
        'up': lambda pc: UpGraphEdges(builders['up'](pc.r[:n_up], pc.r)),
        'down': lambda pc: DownGraphEdges(builders['down'](pc.r[n_up:], pc.r)),
    }

    def build(phys_conf):
        r"""
        Build many types of molecular graph edges.

        Args:
            phys_conf (~deepqmc.types.PhysicalConfiguration): the physical
                configuration of the molecule. ``phys_conf.R`` holds the
                nuclear and ``phys_conf.r`` the electronic coordinates, with
                the ``n_up`` spin-up electrons first.

        Returns:
            dict: mapping of each requested edge type name to its edges:
            :class:`SimpleGraphEdges` for ``'nn'``, ``'ne'`` and ``'en'``, and
            :class:`SameGraphEdges`, :class:`AntiGraphEdges`,
            :class:`UpGraphEdges` and :class:`DownGraphEdges` for ``'same'``,
            ``'anti'``, ``'up'`` and ``'down'``, respectively.
        """
        assert phys_conf.r.shape[0] == n_up + n_down
        edges = {
            edge_type: build_rules[edge_type](phys_conf) for edge_type in edge_types
        }
        return edges

    return build
def GraphUpdate(
    aggregate_edges_for_nodes_fn,
    update_nodes_fn=None,
    update_edges_fn=None,
):
    r"""
    Create a function that updates a graph.

    The update function is tailored to be used in GNNs.

    Args:
        aggregate_edges_for_nodes_fn (~collections.abc.Callable): function
            called as ``aggregate_edges_for_nodes_fn(nodes, edges)`` that
            aggregates the edges for the nodes. Only used if
            ``update_nodes_fn`` is given.
        update_nodes_fn (~collections.abc.Callable): optional, function called
            as ``update_nodes_fn(nodes, aggregated_edges)`` that updates the
            nodes.
        update_edges_fn (~collections.abc.Callable): optional, function called
            as ``update_edges_fn(edges)`` that updates the edges.

    Returns:
        ~collections.abc.Callable: function that takes a ``(nodes, edges)``
        pair and returns the updated :class:`Graph`.
    """

    def update_graph(graph):
        nodes, edges = graph
        if update_nodes_fn:
            aggregated_edges = aggregate_edges_for_nodes_fn(nodes, edges)
            nodes = update_nodes_fn(nodes, aggregated_edges)
        if update_edges_fn:
            # only the input edges are passed; the node update above does not
            # feed into the edge update
            edges = update_edges_fn(edges)
        return Graph(nodes, edges)

    return update_graph
class GraphEdges:
    r"""
    Base class of the edge containers.

    Edge arrays of the subclasses have trailing axes (senders, receivers,
    features); subclasses implement all methods below.
    """

    @property
    def single_array(self):
        r"""All edges as one array."""
        raise NotImplementedError

    def update_from_single_array(self, array):
        r"""Return a new instance from an array shaped like :attr:`single_array`."""
        raise NotImplementedError

    def sum_senders(self, normalize=False):
        r"""Sum (or average, if ``normalize``) the edges over their senders."""
        raise NotImplementedError

    def convolve(self, nodes, normalize=False):
        r"""Weight each edge by its sender's node features and sum over senders."""
        raise NotImplementedError


@jdc.pytree_dataclass
class SimpleGraphEdges(GraphEdges):
    r"""Edges stored in one array of shape (..., senders, receivers, features)."""

    edges: jax.Array

    @property
    def single_array(self):
        return self.edges

    def update_from_single_array(self, array):
        return self.__class__(array)

    def sum_senders(self, normalize=False):
        summed = jnp.sum(self.edges, axis=-3)
        if normalize:
            # divide by max(n_senders, 1) instead of using jnp.mean so that an
            # empty sender axis gives zeros rather than NaN, as in
            # SameGraphEdges and AntiGraphEdges
            summed = summed / max(self.edges.shape[-3], 1)
        return summed

    def convolve(self, nodes, normalize=False):
        # nodes[:, None] broadcasts each sender's features over the receivers
        edge_node_product = self.edges * nodes[:, None]
        return self.__class__(edge_node_product).sum_senders(normalize)


@jdc.pytree_dataclass
class UpGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the leading (spin-up) nodes."""

    def convolve(self, nodes, normalize=False):
        n_sender = self.edges.shape[-3]
        up = self.edges * nodes[:n_sender, None]
        return self.__class__(up).sum_senders(normalize)


@jdc.pytree_dataclass
class DownGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the trailing (spin-down) nodes."""

    def convolve(self, nodes, normalize=False):
        n_sender = self.edges.shape[-3]
        # index from the front: nodes[-0:] would select every node when there
        # are no senders
        down = self.edges * nodes[nodes.shape[0] - n_sender :, None]
        return self.__class__(down).sum_senders(normalize)


@jdc.pytree_dataclass
class SameGraphEdges(GraphEdges):
    r"""
    Edges between electrons of the same spin.

    ``uu`` holds the spin-up -> spin-up and ``dd`` the spin-down -> spin-down
    edges, each with trailing axes (senders, receivers, features).
    """

    uu: jax.Array
    dd: jax.Array

    @property
    def single_array(self):
        batch_dims = self.uu.shape[:-3]
        return jnp.concatenate(
            [
                self.uu.reshape(*batch_dims, -1, self.uu.shape[-1]),
                self.dd.reshape(*batch_dims, -1, self.dd.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        n_up = self.uu.shape[-2]
        n_down = self.dd.shape[-2]
        n_sender_up = self.uu.shape[-3]
        n_sender_down = self.dd.shape[-3]
        uu, dd = jnp.split(array, (n_up * n_sender_up,), axis=-2)
        uu = uu.reshape(*uu.shape[:-2], n_sender_up, n_up, uu.shape[-1])
        dd = dd.reshape(*dd.shape[:-2], n_sender_down, n_down, dd.shape[-1])
        return self.__class__(uu, dd)

    def sum_senders(self, normalize=False):
        norm_uu, norm_dd = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.uu, self.dd)
        )
        up, down = (
            jnp.sum(self.uu, axis=-3) / norm_uu,
            jnp.sum(self.dd, axis=-3) / norm_dd,
        )
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        n_up, n_down = self.uu.shape[-2], self.dd.shape[-2]
        # Without self-interaction each non-empty spin block has one sender
        # fewer than receivers. Compare totals over both blocks so that an
        # empty spin-up block (0 senders, 0 receivers) does not hide the
        # masking of the spin-down block.
        self_interaction = self.uu.shape[-3] + self.dd.shape[-3] == n_up + n_down
        if self_interaction:
            up_node_idx = (slice(None, n_up), None)
            down_node_idx = (slice(n_up, None), None)
        else:
            up_node_idx = offdiagonal_sender_idx(n_up)
            down_node_idx = n_up + offdiagonal_sender_idx(n_down)
        uu = self.uu * nodes[up_node_idx]
        dd = self.dd * nodes[down_node_idx]
        return self.__class__(uu, dd).sum_senders(normalize)


@jdc.pytree_dataclass
class AntiGraphEdges(GraphEdges):
    r"""
    Edges between electrons of opposite spin.

    ``du`` holds the spin-down -> spin-up edges, shape
    (..., n_down, n_up, features), and ``ud`` the spin-up -> spin-down edges,
    shape (..., n_up, n_down, features).
    """

    du: jax.Array
    ud: jax.Array

    @property
    def single_array(self):
        batch_dims = self.du.shape[:-3]
        return jnp.concatenate(
            [
                self.du.reshape(*batch_dims, -1, self.du.shape[-1]),
                self.ud.reshape(*batch_dims, -1, self.ud.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        n_up = self.du.shape[-2]
        n_down = self.ud.shape[-2]
        # split along the flattened edge axis (as in SameGraphEdges), not the
        # leading axis, so that arrays with batch dimensions work
        du, ud = jnp.split(array, (n_up * n_down,), axis=-2)
        du = du.reshape(*du.shape[:-2], n_down, n_up, du.shape[-1])
        ud = ud.reshape(*ud.shape[:-2], n_up, n_down, ud.shape[-1])
        return self.__class__(du, ud)

    def sum_senders(self, normalize=False):
        norm_du, norm_ud = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.du, self.ud)
        )
        up, down = (
            jnp.sum(self.du, axis=-3) / norm_du,
            jnp.sum(self.ud, axis=-3) / norm_ud,
        )
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        n_up = self.du.shape[-2]
        # du is sent by the spin-down nodes, ud by the spin-up nodes
        du = self.du * nodes[n_up:, None]
        ud = self.ud * nodes[:n_up, None]
        return self.__class__(du, ud).sum_senders(normalize)