Source code for deepqmc.gnn.graph

from collections import namedtuple

import jax.numpy as jnp
import jax_dataclasses as jdc

# Pair of node collections of a graph: one entry for nuclei, one for electrons.
GraphNodes = namedtuple(typename='GraphNodes', field_names='nuclei electrons')
# A graph as a pair of its nodes and its edges.
Graph = namedtuple(typename='Graph', field_names='nodes edges')

# Names exported by a star-import of this module: the three factory functions.
__all__ = [
    'GraphEdgeBuilder',
    'MolecularGraphEdgeBuilder',
    'GraphUpdate',
]


def offdiagonal_sender_idx(n_node):
    r"""
    Return sender indices that skip the diagonal of a square node layout.

    The result is an integer array of shape ``(n_node - 1, n_node)``. Entry
    ``[i, j]`` equals ``i + 1`` when ``j <= i`` and ``i`` otherwise, so column
    ``j`` lists every index in ``range(n_node)`` except ``j`` itself.

    Args:
        n_node (int): number of nodes.
    """
    row = jnp.arange(n_node - 1)[:, None]
    col = jnp.arange(n_node)[None, :]
    # Adding the boolean mask bumps the index by one on and below the diagonal.
    return (col <= row) + row


def compute_edges(pos_sender, pos_receiver, filter_diagonal):
    r"""
    Compute the difference vectors between all sender/receiver pairs.

    Args:
        pos_sender (float, (..., :math:`N`, D)): coordinates of the sending nodes.
        pos_receiver (float, (..., :math:`M`, D)): coordinates of the receiving
            nodes.
        filter_diagonal (bool): if ``True``, pairs whose sender and receiver
            index coincide are dropped; this requires :math:`N = M`.

    Returns:
        An array holding ``pos_receiver[j] - pos_sender[i]``. Its shape is
        ``(..., N, M, D)`` when ``filter_diagonal`` is false. Otherwise it is
        ``(..., N - 1, N, D)``, where column ``j`` contains the edges from all
        senders except ``j``; for ``N == 0`` the empty ``(..., 0, 0, D)`` array
        is returned.
    """
    # Sender index on axis -3, receiver index on axis -2.
    diffs = pos_receiver[..., None, :, :] - pos_sender[..., None, :]
    if not filter_diagonal:
        return diffs
    n_node = pos_sender.shape[-2]
    assert n_node == pos_receiver.shape[-2]
    if n_node == 0:
        # No nodes means no pairs. Building the index arrays below would ask
        # for a broadcast to the invalid shape (-1, 0), so return the (already
        # empty) difference array directly.
        return diffs
    receiver_idx = jnp.broadcast_to(jnp.arange(n_node)[None], (n_node - 1, n_node))
    sender_idx = offdiagonal_sender_idx(n_node)
    return diffs[..., sender_idx, receiver_idx, :]


def GraphEdgeBuilder(
    mask_self,
    offsets,
    mask_vals,
):
    r"""
    Create a function that builds graph edges.

    Args:
        mask_self (bool): whether to drop edges between nodes of the same
            index. When set, senders and receivers must be equal in number.
        offsets ((int, int)): node index offsets for the sender and receiver
            nodes respectively. Accepted as part of the interface; the returned
            builder does not read it.
        mask_vals ((int, int)): padding values for the sender and receiver node
            indices respectively. Accepted as part of the interface; the
            returned builder does not read it.
    """

    def build(pos_sender, pos_receiver):
        r"""
        Build graph edges.

        Args:
            pos_sender (float, (:math:`N_{nodes}`, 3)): coordinates of graph
                nodes that send edges.
            pos_receiver (float, (:math:`M_{nodes}`, 3)): coordinates of graph
                nodes that receive edges.

        Returns:
            The array of edge difference vectors produced by
            ``compute_edges(pos_sender, pos_receiver, mask_self)``.
        """
        sender_shape = pos_sender.shape
        receiver_shape = pos_receiver.shape
        # Both inputs must hold 3-dimensional coordinates.
        assert sender_shape[-1] == 3 and receiver_shape[-1] == 3
        # Only the sender array is required to be exactly two-dimensional.
        assert len(sender_shape) == 2
        # Dropping self-edges only makes sense for equally sized node sets.
        assert not mask_self or sender_shape[0] == receiver_shape[0]
        return compute_edges(pos_sender, pos_receiver, mask_self)

    return build
def MolecularGraphEdgeBuilder(n_nuc, n_up, n_down, edge_types, *, self_interaction):
    r"""
    Create a function that builds many types of molecular edges.

    Args:
        n_nuc (int): number of nuclei.
        n_up (int): number of spin-up electrons.
        n_down (int): number of spin-down electrons.
        edge_types (List[str]): list of edge type names to build. Possible names
            are:

            - ``'nn'``: nuclei->nuclei edges
            - ``'ne'``: nuclei->electrons edges
            - ``'en'``: electrons->nuclei edges
            - ``'same'``: edges between same-spin electrons
            - ``'anti'``: edges between opposite-spin electrons
            - ``'up'``: edges going from spin-up electrons to all electrons
            - ``'down'``: edges going from spin-down electrons to all electrons
        self_interaction (bool): whether edges between a particle and itself
            are considered. Applies to the ``'nn'`` and ``'same'`` edge types.
    """
    n_elec = n_up + n_down
    drop_self = not self_interaction

    # Each edge type is assembled from one or two elementary builders.
    builder_mapping = {
        'nn': ['nn'],
        'ne': ['ne'],
        'en': ['en'],
        'same': ['uu', 'dd'],
        'anti': ['ud', 'du'],
        'up': ['up'],
        'down': ['down'],
    }

    def _kwargs(mask_self, offsets, mask_vals):
        # Keyword arguments forwarded to ``GraphEdgeBuilder``.
        return {'mask_self': mask_self, 'offsets': offsets, 'mask_vals': mask_vals}

    fix_kwargs_of_builder_type = {
        'nn': _kwargs(drop_self, (0, 0), (n_nuc, n_nuc)),
        'ne': _kwargs(False, (0, 0), (n_nuc, n_elec)),
        'en': _kwargs(False, (0, 0), (n_elec, n_nuc)),
        'uu': _kwargs(drop_self, (0, 0), (n_elec, n_elec)),
        'dd': _kwargs(drop_self, (n_up, n_up), (n_elec, n_elec)),
        'ud': _kwargs(False, (0, n_up), (n_elec, n_elec)),
        'du': _kwargs(False, (n_up, 0), (n_elec, n_elec)),
        'up': _kwargs(False, (0, 0), (n_elec, n_elec)),
        'down': _kwargs(False, (n_up, 0), (n_elec, n_elec)),
    }

    # Only instantiate the elementary builders the requested edge types need.
    builders = {}
    for edge_type in edge_types:
        for builder_type in builder_mapping[edge_type]:
            builders[builder_type] = GraphEdgeBuilder(
                **fix_kwargs_of_builder_type[builder_type]
            )

    # In the rules below ``pc.R`` is passed for nuclei and ``pc.r`` for
    # electrons; the first ``n_up`` rows of ``pc.r`` are used as the spin-up
    # electrons and the remaining rows as the spin-down electrons.
    def _nn(pc):
        return SimpleGraphEdges(builders['nn'](pc.R, pc.R))

    def _ne(pc):
        return SimpleGraphEdges(builders['ne'](pc.R, pc.r))

    def _en(pc):
        return SimpleGraphEdges(builders['en'](pc.r, pc.R))

    def _same(pc):
        r_up, r_down = pc.r[:n_up], pc.r[n_up:]
        return SameGraphEdges(
            builders['uu'](r_up, r_up),
            builders['dd'](r_down, r_down),
        )

    def _anti(pc):
        return AntiGraphEdges(
            builders['du'](pc.r[n_up:], pc.r[:n_up]),
            builders['ud'](pc.r[:n_up], pc.r[n_up:]),
        )

    def _up(pc):
        return UpGraphEdges(builders['up'](pc.r[:n_up], pc.r))

    def _down(pc):
        return DownGraphEdges(builders['down'](pc.r[n_up:], pc.r))

    build_rules = {
        'nn': _nn,
        'ne': _ne,
        'en': _en,
        'same': _same,
        'anti': _anti,
        'up': _up,
        'down': _down,
    }

    def build(phys_conf):
        r"""
        Build many types of molecular graph edges.

        Args:
            phys_conf (~deepqmc.types.PhysicalConfiguration): the physical
                configuration of the molecule.

        Returns:
            dict: mapping of each requested edge type name to the edge
            container built for it.
        """
        assert phys_conf.r.shape[0] == n_elec
        edges = {}
        for edge_type in edge_types:
            edges[edge_type] = build_rules[edge_type](phys_conf)
        return edges

    return build
def GraphUpdate(
    aggregate_edges_for_nodes_fn,
    update_nodes_fn=None,
    update_edges_fn=None,
):
    r"""
    Create a function that updates a graph.

    The update function is tailored to be used in GNNs.

    Args:
        aggregate_edges_for_nodes_fn (Callable): function called as
            ``aggregate_edges_for_nodes_fn(nodes, edges)``; its result is passed
            to ``update_nodes_fn``. Only called when ``update_nodes_fn`` is
            given.
        update_nodes_fn (Callable): optional, function that updates the nodes,
            called as ``update_nodes_fn(nodes, aggregated_edges)``.
        update_edges_fn (Callable): optional, function that updates the edges,
            called as ``update_edges_fn(edges)``.
    """

    def update_graph(graph):
        nodes, edges = graph
        new_nodes = nodes
        if update_nodes_fn:
            # Nodes are updated from the edges as they were before any edge update.
            aggregated_edges = aggregate_edges_for_nodes_fn(nodes, edges)
            new_nodes = update_nodes_fn(nodes, aggregated_edges)
        new_edges = update_edges_fn(edges) if update_edges_fn else edges
        return Graph(new_nodes, new_edges)

    return update_graph
class GraphEdges:
    r"""
    Common interface of the edge containers defined below.

    The subclasses store edge arrays whose axis ``-3`` is treated as the sender
    axis (it is the axis reduced by ``sum_senders``) and whose axis ``-2`` is
    treated as the receiver axis.
    """

    @property
    def single_array(self):
        r"""Return all edges of the container as one array."""
        raise NotImplementedError

    def update_from_single_array(self, array):
        r"""Return a new container built from an array laid out like ``single_array``."""
        raise NotImplementedError

    def sum_senders(self, normalize=False):
        r"""Reduce the edges over the sender axis (mean if ``normalize``, else sum)."""
        raise NotImplementedError

    def convolve(self, nodes, normalize=False):
        r"""Multiply the edges with the sender entries of ``nodes`` and reduce over senders."""
        raise NotImplementedError


@jdc.pytree_dataclass
class SimpleGraphEdges(GraphEdges):
    r"""Edge container holding a single edge array."""

    # Edge array; axis -3 indexes senders, axis -2 indexes receivers.
    edges: jnp.ndarray

    @property
    def single_array(self):
        return self.edges

    def update_from_single_array(self, array):
        return self.__class__(array)

    def sum_senders(self, normalize=False):
        reduce_fn = jnp.mean if normalize else jnp.sum
        return reduce_fn(self.edges, axis=-3)

    def convolve(self, nodes, normalize=False):
        # ``nodes[:, None]`` adds a receiver axis so each sender row of ``nodes``
        # broadcasts over all receivers.
        edge_node_product = self.edges * nodes[:, None]
        return self.__class__(edge_node_product).sum_senders(normalize)


@jdc.pytree_dataclass
class UpGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the leading entries of ``nodes``."""

    def convolve(self, nodes, normalize=False):
        # The senders are the first ``n_sender`` rows of ``nodes``.
        up = self.edges * nodes[: self.edges.shape[-3], None]
        return self.__class__(up).sum_senders(normalize)


@jdc.pytree_dataclass
class DownGraphEdges(SimpleGraphEdges):
    r"""Edges whose senders are the trailing entries of ``nodes``."""

    def convolve(self, nodes, normalize=False):
        n_sender = self.edges.shape[-3]
        # The senders are the last ``n_sender`` rows of ``nodes``. The start is
        # computed explicitly because ``nodes[-n_sender:]`` would select *all*
        # rows when ``n_sender == 0``.
        start = max(nodes.shape[0] - n_sender, 0)
        down = self.edges * nodes[start:, None]
        return self.__class__(down).sum_senders(normalize)


@jdc.pytree_dataclass
class SameGraphEdges(GraphEdges):
    r"""Edge container holding two edge arrays, ``uu`` and ``dd``."""

    # Each array has senders on axis -3 and receivers on axis -2.
    uu: jnp.ndarray
    dd: jnp.ndarray

    @property
    def single_array(self):
        # Flatten the (sender, receiver) axes of each array and stack the two
        # flattened arrays along the resulting edge axis, ``uu`` first.
        batch_dims = self.uu.shape[:-3]
        return jnp.concatenate(
            [
                self.uu.reshape(*batch_dims, -1, self.uu.shape[-1]),
                self.dd.reshape(*batch_dims, -1, self.dd.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        # Inverse of ``single_array``: split along the edge axis and restore the
        # (sender, receiver) axes using the shapes of the stored arrays.
        n_up = self.uu.shape[-2]
        n_down = self.dd.shape[-2]
        n_sender_up = self.uu.shape[-3]
        n_sender_down = self.dd.shape[-3]
        uu, dd = jnp.split(array, (n_up * n_sender_up,), axis=-2)
        uu = uu.reshape(*uu.shape[:-2], n_sender_up, n_up, uu.shape[-1])
        dd = dd.reshape(*dd.shape[:-2], n_sender_down, n_down, dd.shape[-1])
        return self.__class__(uu, dd)

    def sum_senders(self, normalize=False):
        # When normalizing, divide by the sender count, clamped to one so that
        # an array without senders does not cause a division by zero.
        norm_uu, norm_dd = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.uu, self.dd)
        )
        up = jnp.sum(self.uu, axis=-3) / norm_uu
        down = jnp.sum(self.dd, axis=-3) / norm_dd
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        n_up = self.uu.shape[-2]
        n_down = self.dd.shape[-2]
        # A square ``uu`` is taken to mean that self-edges are present;
        # otherwise the sender axis is assumed to skip the diagonal.
        self_interaction = self.uu.shape[-3] == n_up
        if self_interaction:
            # Plain slices: the first ``n_up`` rows of ``nodes`` send the ``uu``
            # edges, the remaining rows send the ``dd`` edges.
            up_node_idx = (slice(None, n_up), None)
            down_node_idx = (slice(n_up, None), None)
        else:
            # Per-edge sender indices with the diagonal removed; the ``dd``
            # indices are shifted past the first ``n_up`` rows.
            up_node_idx = offdiagonal_sender_idx(n_up)
            down_node_idx = n_up + offdiagonal_sender_idx(n_down)
        uu = self.uu * nodes[up_node_idx]
        dd = self.dd * nodes[down_node_idx]
        return self.__class__(uu, dd).sum_senders(normalize)


@jdc.pytree_dataclass
class AntiGraphEdges(GraphEdges):
    r"""Edge container holding two edge arrays, ``du`` and ``ud``."""

    # Each array has senders on axis -3 and receivers on axis -2.
    du: jnp.ndarray
    ud: jnp.ndarray

    @property
    def single_array(self):
        # Flatten the (sender, receiver) axes of each array and stack the two
        # flattened arrays along the resulting edge axis, ``du`` first.
        batch_dims = self.du.shape[:-3]
        return jnp.concatenate(
            [
                self.du.reshape(*batch_dims, -1, self.du.shape[-1]),
                self.ud.reshape(*batch_dims, -1, self.ud.shape[-1]),
            ],
            axis=-2,
        )

    def update_from_single_array(self, array):
        # Inverse of ``single_array``. ``n_up`` / ``n_down`` are the receiver
        # counts of ``du`` / ``ud`` respectively.
        n_up = self.du.shape[-2]
        n_down = self.ud.shape[-2]
        # Split along the edge axis (-2), the axis ``single_array`` concatenates
        # on. Relying on the default ``axis=0`` is only correct when ``array``
        # has no leading batch dimensions.
        du, ud = jnp.split(array, (n_up * n_down,), axis=-2)
        du = du.reshape(*du.shape[:-2], n_down, n_up, du.shape[-1])
        ud = ud.reshape(*ud.shape[:-2], n_up, n_down, ud.shape[-1])
        return self.__class__(du, ud)

    def sum_senders(self, normalize=False):
        # When normalizing, divide by the sender count, clamped to one so that
        # an array without senders does not cause a division by zero.
        norm_du, norm_ud = (
            max(x.shape[-3], 1) if normalize else 1 for x in (self.du, self.ud)
        )
        up = jnp.sum(self.du, axis=-3) / norm_du
        down = jnp.sum(self.ud, axis=-3) / norm_ud
        return jnp.concatenate([up, down], axis=-2)

    def convolve(self, nodes, normalize=False):
        n_up = self.du.shape[-2]
        # ``du`` edges are sent by the rows of ``nodes`` after the first
        # ``n_up``; ``ud`` edges are sent by the first ``n_up`` rows.
        du = self.du * nodes[n_up:, None]
        ud = self.ud * nodes[:n_up, None]
        return self.__class__(du, ud).sum_senders(normalize)