Source code for deepqmc.gnn.electron_gnn

from functools import partial
from itertools import accumulate

import haiku as hk
import jax.numpy as jnp
from jax import tree_util

from ..hkext import MLP
from .graph import Graph, GraphNodes, GraphUpdate, MolecularGraphEdgeBuilder
from .utils import NodeEdgeMapping


class ElectronGNNLayer(hk.Module):
    r"""
    The message passing layer of :class:`ElectronGNN`.

    Derived from :class:`~deepqmc.gnn.gnn.MessagePassingLayer`.

    Args:
        one_particle_residual: whether a residual connection is used when updating
            the one particle embeddings, either :data:`False`, or an instance of
            :class:`~deepqmc.hkext.ResidualConnection`.
        two_particle_residual: whether a residual connection is used when updating
            the two particle embeddings, either :data:`False`, or an instance of
            :class:`~deepqmc.hkext.ResidualConnection`.
        deep_features: if :data:`False`, the edge features are not updated
            throughout the GNN layers, if :data:`'shared'` then in each layer a
            single MLP (:data:`u`) is used to update all edge types, if
            :data:`'separate'` then in each layer separate MLPs are used to
            update the different edge types.
        update_features (List[~deepqmc.gnn.update_features.UpdateFeature]): a
            list of partially initialized update feature classes to use when
            computing the update features of the one particle embeddings. For
            more details see the documentation of
            :class:`deepqmc.gnn.update_features`.
        update_rule (str): how to combine the update features for the update of
            the one particle embeddings. Possible values:

            - ``'concatenate'``: run the concatenated features through an MLP
            - ``'featurewise'``: apply a different MLP to each feature channel
              and sum the results
            - ``'featurewise_shared'``: apply the same MLP across feature
              channels
            - ``'sum'``: sum the features before sending them through an MLP

            Note that :data:`'sum'` and :data:`'featurewise_shared'` require
            features of the same size.
        subnet_factory (Callable): a function that constructs the subnetworks of
            the GNN layer.
        subnet_factory_by_lbl (dict): optional, a dictionary of functions that
            construct the subnetworks of the GNN layer. If both this and
            :data:`subnet_factory` are specified, the values of
            :data:`subnet_factory_by_lbl` take precedence. If some keys are
            missing, the default value of :data:`subnet_factory` is used in
            their place. Possible keys are :data:`'g'` and :data:`'u'`.
    """

    def __init__(
        self,
        n_interactions,
        ilayer,
        n_nuc,
        n_up,
        n_down,
        embedding_dim,
        edge_types,
        self_interaction,
        node_data,
        two_particle_stream_dim,
        *,
        one_particle_residual,
        two_particle_residual,
        deep_features,
        update_features,
        update_rule,
        subnet_factory=None,
        subnet_factory_by_lbl=None,
    ):
        super().__init__()
        self.n_nuc, self.n_up, self.n_down = n_nuc, n_up, n_down
        self.last_layer = ilayer == n_interactions - 1
        self.edge_types = tuple(
            typ for typ in edge_types if not self.last_layer or typ not in {'nn', 'en'}
        )
        self.mapping = NodeEdgeMapping(self.edge_types, node_data=node_data)
        assert update_rule in [
            'concatenate',
            'featurewise',
            'featurewise_shared',
            'sum',
        ]
        assert (
            update_rule not in ['sum', 'featurewise_shared']
            or embedding_dim == two_particle_stream_dim
        )
        assert deep_features in [False, 'shared', 'separate']
        self.deep_features = deep_features
        self.update_rule = update_rule
        subnet_factory_by_lbl = subnet_factory_by_lbl or {}
        for lbl in ['g', 'u']:
            subnet_factory_by_lbl.setdefault(lbl, subnet_factory)
        if deep_features:
            self.u = (
                subnet_factory_by_lbl['u'](two_particle_stream_dim, name='u')
                if deep_features == 'shared'
                else {
                    typ: subnet_factory_by_lbl['u'](
                        two_particle_stream_dim,
                        name=f'u{typ}',
                    )
                    for typ in self.edge_types
                }
            )
        self.update_features = [
            uf(self.n_up, self.n_down, two_particle_stream_dim, self.mapping)
            for uf in update_features
        ]
        self.g = (
            subnet_factory_by_lbl['g'](
                embedding_dim,
                name='g',
            )
            if self.update_rule != 'featurewise'
            else {
                name: subnet_factory_by_lbl['g'](
                    embedding_dim,
                    name=f'g_{name}',
                )
                for uf in self.update_features
                for name in uf.names
            }
        )
        self.one_particle_residual = one_particle_residual
        self.two_particle_residual = two_particle_residual
        self.self_interaction = self_interaction

    def get_update_edges_fn(self):
        def update_edges(edges):
            if self.deep_features:
                if self.deep_features == 'shared':
                    # combine features along the leading dim, apply the MLP and
                    # split into channels again to please kfac
                    keys, edge_objects = zip(*edges.items())
                    feats = [e.single_array for e in edge_objects]
                    split_idxs = list(accumulate(len(f) for f in feats))
                    feats = jnp.split(self.u(jnp.concatenate(feats)), split_idxs)
                    edge_objects = [
                        e.update_from_single_array(f)
                        for e, f in zip(edge_objects, feats)
                    ]
                    updated_edges = dict(zip(keys, edge_objects))
                elif self.deep_features == 'separate':
                    updated_edges = {
                        typ: edge.update_from_single_array(
                            self.u[typ](edge.single_array)
                        )
                        for typ, edge in edges.items()
                    }
                if self.two_particle_residual:
                    updated_edges = self.two_particle_residual(edges, updated_edges)
                return updated_edges
            else:
                return edges

        return update_edges

    def get_aggregate_edges_for_nodes_fn(self):
        def aggregate_edges_for_nodes(nodes, edges):
            f = []
            for uf in self.update_features:
                f.extend(uf(nodes, edges))
            return f

        return aggregate_edges_for_nodes

    def get_update_nodes_fn(self):
        def update_nodes(nodes, f):
            if self.update_rule == 'concatenate':
                updated = self.g(jnp.concatenate(f, axis=-1))
            elif self.update_rule == 'featurewise':
                updated = sum(self.g[name](fi) for fi, name in zip(f, self.g.keys()))
            elif self.update_rule == 'sum':
                updated = self.g(sum(f))
            elif self.update_rule == 'featurewise_shared':
                updated = jnp.sum(self.g(jnp.stack(f)), axis=0)
            if self.one_particle_residual:
                updated = self.one_particle_residual(nodes.electrons, updated)
            nodes = GraphNodes(nodes.nuclei, updated)
            return nodes

        return update_nodes

    def __call__(self, graph):
        r"""
        Execute the message passing layer.

        Args:
            graph (:class:`Graph`)

        Returns:
            :class:`Graph`: the updated graph
        """
        update_graph = GraphUpdate(
            update_nodes_fn=self.get_update_nodes_fn(),
            update_edges_fn=None if self.last_layer else self.get_update_edges_fn(),
            aggregate_edges_for_nodes_fn=self.get_aggregate_edges_for_nodes_fn(),
        )
        return update_graph(graph)
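# A minimal sketch (not part of the module) of how the four update rules combine
# the list of update feature channels ``f`` into a single node update; ``g`` is a
# hypothetical identity stand-in for the 'g' subnetwork built in __init__ above:
#
#     import jax.numpy as jnp
#
#     f = [jnp.ones((3, 4)), 2 * jnp.ones((3, 4))]  # two channels, 3 nodes each
#     g = lambda h: h  # identity stand-in for the MLP subnetwork
#
#     g(jnp.concatenate(f, axis=-1))     # 'concatenate': (3, 8) into one MLP
#     g(sum(f))                          # 'sum': channels summed first, (3, 4)
#     jnp.sum(g(jnp.stack(f)), axis=0)   # 'featurewise_shared': one MLP per stack
#     # 'featurewise' applies a separate MLP to each channel and sums the results.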
class ElectronGNN(hk.Module):
    r"""
    A neural network acting on graphs defined by electrons and nuclei.

    Derived from :class:`~deepqmc.gnn.gnn.GraphNeuralNetwork`.

    Args:
        hamil (:class:`~deepqmc.MolecularHamiltonian`): the Hamiltonian of the
            system on which the graph is defined.
        embedding_dim (int): the length of the electron embedding vectors.
        n_interactions (int): the number of message passing interactions.
        edge_features (dict): a :data:`dict` of functions for each edge type,
            embedding the interparticle differences. Valid keys are:

            - ``'ne'``: for nucleus-electron edges
            - ``'nn'``: for nucleus-nucleus edges
            - ``'same'``: for same-spin electron-electron edges
            - ``'anti'``: for opposite-spin electron-electron edges
            - ``'up'``: for edges going from spin-up electrons to all electrons
            - ``'down'``: for edges going from spin-down electrons to all
              electrons
        self_interaction (bool): whether to consider edges where the sender and
            receiver electrons are the same.
        two_particle_stream_dim (int): the feature dimension of the two particle
            streams. Only active if :data:`deep_features` are used.
        nuclei_embedding (Union[None,~deepqmc.gnn.electron_gnn.NucleiEmbedding]):
            optional, the instance responsible for creating the initial nuclear
            embeddings. Set to :data:`None` if nuclear embeddings are not needed.
        electron_embedding (~deepqmc.gnn.electron_gnn.ElectronEmbedding): the
            instance that creates the initial electron embeddings.
        layer_factory (Callable): a callable that generates a layer of the GNN.
        ghost_coords: optional, specifies the coordinates of one or more ghost
            atoms, useful for breaking spatial symmetries of the nuclear
            geometry.
    """

    def __init__(
        self,
        hamil,
        embedding_dim,
        *,
        n_interactions,
        edge_features,
        self_interaction,
        two_particle_stream_dim,
        nuclei_embedding,
        electron_embedding,
        layer_factory,
        ghost_coords=None,
    ):
        super().__init__()
        n_nuc, n_up, n_down = hamil.n_nuc, hamil.n_up, hamil.n_down
        n_atom_types = hamil.mol.n_atom_types
        charges = hamil.mol.charges
        self.ghost_coords = None
        if ghost_coords is not None:
            charges = jnp.concatenate([charges, jnp.zeros(len(ghost_coords))])
            n_nuc += len(ghost_coords)
            n_atom_types += 1
            self.ghost_coords = jnp.asarray(ghost_coords)
        self.n_nuc, self.n_up, self.n_down = n_nuc, n_up, n_down
        self.embedding_dim = embedding_dim
        self.node_data = {
            'n_nodes': {'nuclei': n_nuc, 'electrons': n_up + n_down},
            'n_node_types': {'electrons': 1 if n_up == n_down else 2},
            'node_types': {
                'electrons': jnp.array(n_up * [0] + n_down * [int(n_up != n_down)])
            },
        }
        self.edge_types = tuple((edge_features or {}).keys())
        self.layers = [
            layer_factory(
                n_interactions,
                ilayer,
                n_nuc,
                n_up,
                n_down,
                embedding_dim,
                self.edge_types,
                self_interaction,
                self.node_data,
                two_particle_stream_dim,
            )
            for ilayer in range(n_interactions)
        ]
        self.edge_features = edge_features
        self.nuclei_embedding = (
            nuclei_embedding(charges, n_atom_types) if nuclei_embedding else None
        )
        self.electron_embedding = electron_embedding(
            n_nuc,
            n_up,
            n_down,
            embedding_dim,
            self.node_data['n_node_types']['electrons'],
            self.node_data['node_types']['electrons'],
        )
        self.self_interaction = self_interaction

    def node_factory(self, phys_conf):
        electron_embedding = self.electron_embedding(phys_conf)
        nucleus_embedding = self.nuclei_embedding() if self.nuclei_embedding else None
        return GraphNodes(nucleus_embedding, electron_embedding)

    def edge_factory(self, phys_conf):
        r"""Compute all the graph edges used in the GNN."""
        edge_factory = MolecularGraphEdgeBuilder(
            self.n_nuc,
            self.n_up,
            self.n_down,
            self.edge_types,
            self_interaction=self.self_interaction,
        )
        edges = edge_factory(phys_conf)
        return {
            typ: edges[typ].update_from_single_array(
                self.edge_features[typ](edges[typ].single_array)
            )
            for typ in self.edge_types
        }

    def __call__(self, phys_conf):
        r"""
        Execute the graph neural network.

        Args:
            phys_conf (PhysicalConfiguration): the physical configuration of the
                molecule.

        Returns:
            float, (:math:`N_\text{elec}`, :data:`embedding_dim`): the final
            embeddings of the electrons.
        """
        if self.ghost_coords is not None:
            phys_conf = phys_conf._replace(
                R=jnp.concatenate(
                    [
                        phys_conf.R,
                        jnp.tile(self.ghost_coords[None], (len(phys_conf.R), 1, 1)),
                    ],
                    axis=-2,
                )
            )
        graph_edges = self.edge_factory(phys_conf)
        graph_nodes = self.node_factory(phys_conf)
        graph = Graph(graph_nodes, graph_edges)
        for layer in self.layers:
            graph = layer(graph)
        return graph.nodes.electrons
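# A minimal usage sketch (assumptions: ``hamil`` is a deepqmc MolecularHamiltonian,
# ``phys_conf`` a matching PhysicalConfiguration, and ``edge_features``,
# ``nuclei_embedding``, ``electron_embedding`` and ``layer_factory`` are partially
# initialized elsewhere; the hyperparameters shown are illustrative, not library
# defaults):
#
#     import haiku as hk
#
#     @hk.without_apply_rng
#     @hk.transform
#     def gnn(phys_conf):
#         net = ElectronGNN(
#             hamil,
#             embedding_dim=64,
#             n_interactions=3,
#             edge_features=edge_features,
#             self_interaction=False,
#             two_particle_stream_dim=32,
#             nuclei_embedding=nuclei_embedding,
#             electron_embedding=electron_embedding,
#             layer_factory=layer_factory,
#         )
#         return net(phys_conf)
#
#     params = gnn.init(rng, phys_conf)
#     embeddings = gnn.apply(params, phys_conf)  # shape: (n_up + n_down, 64)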
class NucleiEmbedding(hk.Module):
    r"""Create initial embeddings for nuclei.

    Args:
        charges (jnp.ndarray): the nuclear charges of the molecule.
        n_atom_types (int): the number of different atom types in the molecule.
        embedding_dim (int): the length of the output embedding vector.
        atom_type_embedding (bool): if :data:`True`, initial embeddings are the
            same for atoms of the same type (nuclear charge), otherwise they are
            different for all nuclei.
        subnet_type (str): the type of subnetwork to use for the embedding
            generation:

            - ``'mlp'``: an MLP is used
            - ``'embed'``: a :class:`haiku.Embed` block is used
    """

    def __init__(
        self, charges, n_atom_types, *, embedding_dim, atom_type_embedding, subnet_type
    ):
        super().__init__()
        assert subnet_type in ['mlp', 'embed']
        n_nuc_types = n_atom_types if atom_type_embedding else len(charges)
        if subnet_type == 'mlp':
            self.subnet = MLP(
                embedding_dim,
                hidden_layers=['log', 1],
                bias=True,
                last_linear=False,
                activation=jnp.tanh,
                init='deeperwin',
            )
        elif subnet_type == 'embed':
            self.subnet = hk.Embed(n_nuc_types, embedding_dim)
        self.input = (
            jnp.arange(len(charges))
            if not atom_type_embedding
            else (
                charges
                if subnet_type == 'mlp'
                else jnp.unique(charges, size=len(charges), return_inverse=True)[-1]
            )
        )
        if subnet_type == 'mlp':
            self.input = self.input[:, None]

    def __call__(self):
        return self.subnet(self.input)
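# Sketch of the index construction above for the 'embed' subnet with
# atom_type_embedding=True (the charges are illustrative, not tied to any
# particular library example):
#
#     charges = jnp.array([8.0, 1.0, 1.0])  # e.g. water: O, H, H
#     idxs = jnp.unique(charges, size=len(charges), return_inverse=True)[-1]
#     # idxs == [1, 0, 0]: both hydrogens map to the same embedding row, so
#     # atoms of the same type receive identical initial embeddings.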
class ElectronEmbedding(hk.Module):
    r"""Create initial embeddings for electrons.

    Args:
        n_nuc (int): the number of nuclei.
        n_up (int): the number of spin-up electrons.
        n_down (int): the number of spin-down electrons.
        embedding_dim (int): the desired length of the embedding vectors.
        n_elec_types (int): the number of electron types to differentiate.
            Usual values are:

            - ``1``: treat all electrons as indistinguishable. Note that
              electrons with different spins can still become distinguishable
              during the later embedding update steps of the GNN.
            - ``2``: treat spin-up and spin-down electrons as distinguishable
              already in the initial embeddings.
        elec_types (jax.Array): an integer array with length equal to the number
            of electrons, with entries between ``0`` (inclusive) and
            ``n_elec_types`` (exclusive). Specifies the type of each electron.
        positional_embeddings (Union[Literal[False],dict]): if not ``False``, a
            ``dict`` with edge types as keys and edge features as values.
            Specifies the edge types and edge features to use when constructing
            the positional initial electron embeddings.
        use_spin (bool): only relevant if ``positional_embeddings`` is not
            ``False``; if ``True``, concatenate the spin of the given electron
            after the positional embedding features.
        project_to_embedding_dim (bool): only relevant if
            ``positional_embeddings`` is not ``False``; if ``True``, use a
            linear layer to project the initial embeddings to length
            ``embedding_dim``.
    """

    def __init__(
        self,
        n_nuc,
        n_up,
        n_down,
        embedding_dim,
        n_elec_types,
        elec_types,
        *,
        positional_embeddings,
        use_spin,
        project_to_embedding_dim,
    ):
        super().__init__()
        self.n_nuc = n_nuc
        self.n_up = n_up
        self.n_down = n_down
        self.embedding_dim = embedding_dim
        self.n_elec_types = n_elec_types
        self.elec_types = elec_types
        self.positional_embeddings = positional_embeddings
        self.use_spin = use_spin
        self.project_to_embedding_dim = project_to_embedding_dim

    def __call__(self, phys_conf):
        if self.positional_embeddings:
            edge_factory = MolecularGraphEdgeBuilder(
                self.n_nuc,
                self.n_up,
                self.n_down,
                self.positional_embeddings.keys(),
                self_interaction=False,
            )
            feats = tree_util.tree_map(
                lambda f, e: f(e.single_array)
                .swapaxes(0, 1)
                .reshape(self.n_up + self.n_down, -1),
                self.positional_embeddings,
                edge_factory(phys_conf),
            )
            x = tree_util.tree_reduce(partial(jnp.concatenate, axis=1), feats)
            if self.use_spin:
                spins = jnp.concatenate(
                    [jnp.ones(self.n_up), -jnp.ones(self.n_down)]
                )[:, None]
                x = jnp.concatenate([x, spins], axis=1)
            if self.project_to_embedding_dim:
                x = hk.Linear(self.embedding_dim, with_bias=False)(x)
        else:
            x = hk.Embed(
                self.n_elec_types, self.embedding_dim, name='ElectronicEmbedding'
            )(self.elec_types)
        return x
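# Sketch of the non-positional branch above (illustrative sizes): with
# n_elec_types == 2, spin-up and spin-down electrons look up different rows of
# the embedding table, matching the node_types built in ElectronGNN.__init__:
#
#     n_up, n_down, dim = 2, 1, 8
#     elec_types = jnp.array(n_up * [0] + n_down * [1])  # [0, 0, 1]
#     embed = hk.Embed(2, dim)    # inside hk.transform in practice
#     x = embed(elec_types)       # shape: (3, 8); x[0] and x[1] are identical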