Source code for deepqmc.gnn.update_features

from collections.abc import Mapping, Sequence

import haiku as hk
import jax.numpy as jnp

from ..hkext import Identity
from .graph import GraphEdges, GraphNodes
from .utils import NodeEdgeMapping


class UpdateFeature(hk.Module):
    r"""Common parent class of the update features.

    Stores the electron counts, the width of the two-particle stream and the
    node/edge mapping on the instance. Both :attr:`names` and ``__call__``
    raise :class:`NotImplementedError` here and are meant to be overridden.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
    """

    def __init__(
        self,
        n_up: int,
        n_down: int,
        two_particle_stream_dim: int,
        node_edge_mapping: NodeEdgeMapping,
    ):
        super().__init__()
        # Electron counts per spin channel.
        self.n_up, self.n_down = n_up, n_down
        self.two_particle_stream_dim = two_particle_stream_dim
        self.node_edge_mapping = node_edge_mapping

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Compute the update features; not implemented on the base class."""
        raise NotImplementedError

    @property
    def names(self) -> list[str]:
        """String names of the update features; not implemented on the base class."""
        raise NotImplementedError
class ResidualElectronUpdateFeature(UpdateFeature):
    r"""Residual update feature.

    Passes the electron embeddings of the previous layer through unchanged and
    exposes them as one update feature.
    """

    @property
    def names(self) -> list[str]:
        """Return the single feature name, ``'residual'``."""
        return ['residual']

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Wrap ``nodes.electrons`` in a one-element list of ``GraphNodes``.

        ``edges`` is accepted for interface compatibility and is not read.
        """
        # First GraphNodes field is left as None; only electrons are carried over.
        passthrough = GraphNodes(None, nodes.electrons)
        return [passthrough]
class NodeSumElectronUpdateFeature(UpdateFeature):
    r"""The (normalized) sum of the node embeddings as an update feature.

    For every requested spin channel the electron embeddings of that channel
    are reduced over the node axis (mean if ``normalize`` else sum) and the
    result is repeated for every electron, yielding one update feature per
    entry of ``node_types``.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
        node_types (list[str]): list of node types to update, each ``'up'`` or
            ``'down'``
        normalize (bool): whether to normalize the sum by the number of nodes
    """

    def __init__(self, *args, node_types, normalize):
        assert not any(
            node_type not in {'up', 'down'} for node_type in node_types
        )
        super().__init__(*args)
        self.node_types = node_types
        self.normalize = normalize

    @property
    def names(self) -> list[str]:
        """Return ``'node_<type>'`` for every configured node type."""
        return [f'node_{node_type}' for node_type in self.node_types]

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Build one pooled-and-tiled electron feature per node type.

        ``edges`` is accepted for interface compatibility and is not read.
        """
        n_elec = self.n_up + self.n_down
        # Spin-up electrons occupy the first n_up rows, spin-down the rest.
        spin_slices = {
            'up': slice(None, self.n_up),
            'down': slice(self.n_up, None),
        }
        pool = jnp.mean if self.normalize else jnp.sum

        features = []
        for node_type in self.node_types:
            selected = nodes.electrons[spin_slices[node_type]]
            pooled = pool(selected, axis=0, keepdims=True)
            # Repeat the pooled row once per electron.
            features.append(GraphNodes(None, jnp.tile(pooled, (n_elec, 1))))
        return features
class EdgeSumElectronUpdateFeature(UpdateFeature):
    r"""The (normalized) sum of the edge embeddings as an update feature.

    Produces one update feature per requested edge type by summing the edge
    embeddings over the senders. The ``'ee'`` type is assembled from the
    ``'same'`` and ``'anti'`` edges.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
        edge_types (list[str]): list of edge types to sum over
        normalize (bool): whether to normalize the sum by the number of senders
    """

    def __init__(self, *args, edge_types, normalize):
        assert not any(
            edge_type not in {'up', 'down', 'same', 'anti', 'ee', 'ne'}
            for edge_type in edge_types
        )
        super().__init__(*args)
        self.edge_types = edge_types
        self.normalize = normalize

    @property
    def names(self) -> list[str]:
        """Return ``'edge_<type>'`` for every configured edge type."""
        return [f'edge_{edge_type}' for edge_type in self.edge_types]

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Sum the edge embeddings over senders for each edge type.

        ``nodes`` is accepted for interface compatibility and is not read.
        """
        updates = []
        for edge_type in self.edge_types:
            if edge_type != 'ee':
                summed = edges[edge_type].sum_senders(self.normalize)
                updates.append(GraphNodes(None, summed))
                continue
            # 'ee' combines both spin relations: add the unnormalized sums and
            # divide by the total electron count only when normalizing.
            factor = self.n_up + self.n_down if self.normalize else 1.0
            same_sum = edges['same'].sum_senders(False)
            anti_sum = edges['anti'].sum_senders(False)
            updates.append(GraphNodes(None, (same_sum + anti_sum) / factor))
        return updates
class ConvolutionElectronUpdateFeature(UpdateFeature):
    r"""The convolution of node and edge embeddings as an update feature.

    Returns the convolution of the node and edge embeddings for various edge
    types as separate update features.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
        edge_types (list[str]): list of edge types to sum over
        normalize (bool): whether to normalize the sum by the number of senders
        w_factory (~collections.abc.Callable): factory function for the
            :math:`w` matrix
        h_factory (~collections.abc.Callable): factory function for the
            :math:`h` matrix
        w_for_ne (bool): whether to use the :math:`w` matrix for the
            :math:`ne` edge type
    """

    def __init__(
        self, *args, edge_types, normalize, w_factory, h_factory, w_for_ne=True
    ):
        assert all(
            edge_type in {'up', 'down', 'same', 'anti', 'ee', 'ne'}
            for edge_type in edge_types
        )
        super().__init__(*args)
        self.normalize = normalize
        self.edge_types = edge_types
        self.h_factory = h_factory
        self.w_factory = w_factory
        self.w_for_ne = w_for_ne

    def single_edge_type_update(self, nodes, edges, edge_type, normalize):
        """Convolve transformed edge embeddings with transformed sender nodes.

        Args:
            nodes: the graph nodes, passed to ``node_edge_mapping.sender_data_of``.
            edges: mapping from edge type to the edges of that type.
            edge_type (str): key into ``edges`` selecting the edges to process.
            normalize (bool): forwarded to ``convolve``.

        Returns:
            The result of ``convolve`` for this edge type, or an all-zero array
            when the edge array is empty.
        """
        # The 'ne' edges skip the w transformation when w_for_ne is False.
        w = (
            self.w_factory(self.two_particle_stream_dim, name=f'w_{edge_type}')
            if self.w_for_ne or edge_type != 'ne'
            else Identity()
        )
        we = w(edges[edge_type].single_array)
        # h maps the sender node data to the feature width of the edges.
        h = self.h_factory(we.shape[-1], name=f'h_{edge_type}')
        hx = h(self.node_edge_mapping.sender_data_of(edge_type, nodes))
        if edges[edge_type].single_array.size == 0:
            # parameters acting on size zero arrays cause NaN gradients
            # NOTE(review): the placeholder has hx.shape[0] rows and
            # two_particle_stream_dim columns -- confirm this matches the
            # shape convolve() would return for this edge type.
            return jnp.zeros((hx.shape[0], self.two_particle_stream_dim))
        return edges[edge_type].update_from_single_array(we).convolve(hx, normalize)

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Return one convolution update feature per configured edge type."""
        updates = []
        for edge_type in self.edge_types:
            if edge_type == 'ee':
                # 'ee' is the sum of the unnormalized 'same' and 'anti'
                # convolutions, divided by the electron count if normalizing.
                ee = sum(
                    self.single_edge_type_update(nodes, edges, st, False)
                    for st in ['same', 'anti']
                )
                factor = self.n_up + self.n_down if self.normalize else 1.0
                updates.append(GraphNodes(None, ee / factor))
            else:
                updates.append(
                    GraphNodes(
                        None,
                        self.single_edge_type_update(
                            nodes, edges, edge_type, self.normalize
                        ),
                    )
                )
        return updates

    @property
    def names(self) -> list[str]:
        """Return ``'conv_<type>'`` for every configured edge type."""
        return [f'conv_{edge_type}' for edge_type in self.edge_types]
class NodeAttentionElectronUpdateFeature(UpdateFeature):
    r"""Create a single update feature by attending over the nodes.

    Returns the Psiformer update feature based on attention over the nodes:
    multi-head self-attention over the electron embeddings followed by an MLP,
    each optionally wrapped by a residual connection.

    This class does not override ``names``; accessing it raises
    :class:`NotImplementedError` from the base class.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
        num_heads (int): number of attention heads
        mlp_factory (~typing.Type[~deepqmc.hkext.MLP]): factory function for
            the MLP
        attention_residual (Optional[~deepqmc.hkext.Residual]): optional
            residual connection after the attention layer
        mlp_residual (Optional[~deepqmc.hkext.Residual]): optional residual
            connection after the MLP layer
    """

    def __init__(self, *args, num_heads, mlp_factory, attention_residual, mlp_residual):
        super().__init__(*args)
        self.num_heads = num_heads
        self.mlp_factory = mlp_factory
        self.attention_residual = attention_residual
        self.mlp_residual = mlp_residual

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Apply self-attention and an MLP to the electron embeddings.

        ``edges`` is accepted for interface compatibility and is not read.
        """
        embeddings = nodes.electrons
        embedding_dim = embeddings.shape[-1]
        # The embedding width must split evenly across the attention heads.
        heads_dim, leftover = divmod(embedding_dim, self.num_heads)
        assert leftover == 0

        attention_layer = hk.MultiHeadAttention(
            self.num_heads,
            heads_dim,
            w_init=hk.initializers.VarianceScaling(1, 'fan_in', 'normal'),
            with_bias=False,
        )
        mlp = self.mlp_factory(embedding_dim, name='mlp')

        # Query, key and value are all the electron embeddings.
        attended = attention_layer(embeddings, embeddings, embeddings)
        if self.attention_residual:
            attended = self.attention_residual(embeddings, attended)

        transformed = mlp(attended)
        if self.mlp_residual:
            transformed = self.mlp_residual(attended, transformed)
        return [GraphNodes(None, transformed)]
class CombinedNodeAttentionUpdateFeature(UpdateFeature):
    r"""Create an attention update feature for both electrons and nuclei.

    The nuclear and electronic embeddings are concatenated (nuclei first),
    passed through multi-head self-attention and an MLP, and split back into a
    nuclear and an electronic part.

    This class does not override ``names``; accessing it raises
    :class:`NotImplementedError` from the base class.

    Args:
        n_up (int): number of spin up electrons
        n_down (int): number of spin down electrons
        two_particle_stream_dim (int): dimension of the two-particle stream
        node_edge_mapping (~deepqmc.gnn.utils.NodeEdgeMapping): mapping between
            the various node and edge types.
        num_heads (int): number of attention heads
        mlp_factory (~typing.Type[~deepqmc.hkext.MLP]): factory function for
            the MLP
        attention_residual (~deepqmc.hkext.Residual): optional, residual
            connection after the attention layer
        mlp_residual (~deepqmc.hkext.Residual): optional, residual connection
            after the MLP layer
        elec_to_nuc (bool): whether to allow attention over the electrons to
            influence the nuclei
    """

    def __init__(
        self,
        *args,
        num_heads,
        mlp_factory,
        attention_residual,
        mlp_residual,
        elec_to_nuc,
    ):
        super().__init__(*args)
        self.num_heads = num_heads
        self.mlp_factory = mlp_factory
        self.attention_residual = attention_residual
        self.mlp_residual = mlp_residual
        self.elec_to_nuc = elec_to_nuc

    def __call__(
        self, nodes: GraphNodes, edges: Mapping[str, GraphEdges]
    ) -> Sequence[GraphNodes]:
        """Attend jointly over nuclei and electrons and return both updates.

        ``edges`` is accepted for interface compatibility and is not read.
        """
        n_nuc = len(nodes.nuclei)
        n_el = len(nodes.electrons)
        combined = jnp.concatenate([nodes.nuclei, nodes.electrons], axis=0)

        if self.elec_to_nuc:
            mask = None
        else:
            # Disable the (nucleus row, electron column) block of the attention
            # mask; every other entry stays enabled.
            n_total = n_nuc + n_el
            mask = jnp.ones((1, n_total, n_total), dtype=bool)
            mask = mask.at[:, :n_nuc, n_nuc:].set(False)

        embedding_dim = combined.shape[-1]
        # The embedding width must split evenly across the attention heads.
        heads_dim, leftover = divmod(embedding_dim, self.num_heads)
        assert leftover == 0

        attention_layer = hk.MultiHeadAttention(
            self.num_heads,
            heads_dim,
            w_init=hk.initializers.VarianceScaling(1, 'fan_in', 'normal'),
            with_bias=False,
        )
        mlp = self.mlp_factory(embedding_dim, name='mlp')

        attended = attention_layer(combined, combined, combined, mask)
        if self.attention_residual:
            attended = self.attention_residual(combined, attended)

        transformed = mlp(attended)
        if self.mlp_residual:
            transformed = self.mlp_residual(attended, transformed)

        # Undo the concatenation: the first n_nuc rows belong to the nuclei.
        nuclei_out, electrons_out = jnp.split(transformed, [n_nuc], axis=0)
        return [GraphNodes(nuclei_out, electrons_out)]