Source code for deepqmc.gnn.update_features

from typing import Sequence

import haiku as hk
import jax.numpy as jnp

from ..hkext import Identity


class UpdateFeature(hk.Module):
    r"""Base class for all update features."""

    def __init__(self, n_up, n_down, two_particle_stream_dim, node_edge_mapping):
        super().__init__()
        self.n_up = n_up
        self.n_down = n_down
        self.node_edge_mapping = node_edge_mapping
        self.two_particle_stream_dim = two_particle_stream_dim

    @property
    def names(self) -> Sequence[str]:
        raise NotImplementedError

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        raise NotImplementedError
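
# Illustrative sketch (not part of deepqmc): a minimal custom update feature
# showing the interface required by the base class above. ``names`` must
# return one label per array returned by ``__call__``; the only assumption
# made here is that ``nodes.electrons`` holds the per-electron embeddings, as
# used by the subclasses below.
class _ExampleScaledResidualFeature(UpdateFeature):
    r"""Hypothetical update feature that rescales the previous embeddings."""

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        # one array of shape [n_up + n_down, embedding_dim] per name
        return [0.5 * nodes.electrons]

    @property
    def names(self) -> Sequence[str]:
        return ['scaled_residual']
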
class ResidualUpdateFeature(UpdateFeature):
    r"""Residual update feature.

    Returns the unchanged electron embeddings from the previous layer as a
    single update feature.
    """

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        return [nodes.electrons]

    @property
    def names(self):
        return ['residual']

class NodeSumUpdateFeature(UpdateFeature):
    r"""The (normalized) sum of the node embeddings as an update feature.

    Returns the (normalized) sum of the electron embeddings from the previous
    layer for the requested node types as separate update features.
    """

    def __init__(self, *args, node_types, normalize):
        assert all(node_type in {'up', 'down'} for node_type in node_types)
        super().__init__(*args)
        self.normalize = normalize
        self.node_types = node_types

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        node_idx = {'up': slice(None, self.n_up), 'down': slice(self.n_up, None)}
        reduce_fn = jnp.mean if self.normalize else jnp.sum
        return [
            jnp.tile(
                reduce_fn(
                    nodes.electrons[node_idx[node_type]], axis=0, keepdims=True
                ),
                (self.n_up + self.n_down, 1),
            )
            for node_type in self.node_types
        ]

    @property
    def names(self):
        return [f'node_{node_type}' for node_type in self.node_types]
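
# Hedged shape sketch (helper not part of deepqmc): NodeSumUpdateFeature
# reduces the rows of one spin channel to a single summary vector and tiles it
# back to one row per electron. The standalone helper below reproduces that
# pattern on a plain [n_electrons, dim] array for the spin-up channel.
def _example_node_sum(h: jnp.ndarray, n_up: int, normalize: bool) -> jnp.ndarray:
    reduce_fn = jnp.mean if normalize else jnp.sum
    up_summary = reduce_fn(h[:n_up], axis=0, keepdims=True)  # [1, dim]
    return jnp.tile(up_summary, (h.shape[0], 1))  # [n_electrons, dim]
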
class EdgeSumUpdateFeature(UpdateFeature):
    r"""The (normalized) sum of the edge embeddings as an update feature.

    Returns the (normalized) sum of the edge embeddings for various edge types
    as separate update features.
    """

    def __init__(self, *args, edge_types, normalize):
        assert all(
            edge_type in {'up', 'down', 'same', 'anti', 'ee', 'ne'}
            for edge_type in edge_types
        )
        super().__init__(*args)
        self.normalize = normalize
        self.edge_types = edge_types

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        updates = []
        for edge_type in self.edge_types:
            if edge_type == 'ee':
                factor = self.n_up + self.n_down if self.normalize else 1.0
                updates.append(
                    (
                        edges['same'].sum_senders(False)
                        + edges['anti'].sum_senders(False)
                    )
                    / factor
                )
            else:
                updates.append(edges[edge_type].sum_senders(self.normalize))
        return updates

    @property
    def names(self):
        return [f'edge_{edge_type}' for edge_type in self.edge_types]
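
# Hedged sketch (helper not part of deepqmc): for the combined 'ee' feature
# the unnormalized sums over same-spin and anti-spin senders are added first
# and only then divided by the total electron count, so both spin channels
# share a single normalization instead of being normalized per channel.
def _example_ee_sum(
    same_sum: jnp.ndarray, anti_sum: jnp.ndarray, n_up: int, n_down: int
) -> jnp.ndarray:
    return (same_sum + anti_sum) / (n_up + n_down)
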
class ConvolutionUpdateFeature(UpdateFeature):
    r"""The convolution of node and edge embeddings as an update feature.

    Returns the convolution of the node and edge embeddings for various edge
    types as separate update features.
    """

    def __init__(
        self, *args, edge_types, normalize, w_factory, h_factory, w_for_ne=True
    ):
        assert all(
            edge_type in {'up', 'down', 'same', 'anti', 'ee', 'ne'}
            for edge_type in edge_types
        )
        super().__init__(*args)
        self.normalize = normalize
        self.edge_types = edge_types
        layer_types = [typ for typ in edge_types if typ != 'ee']
        if 'ee' in edge_types:
            layer_types.extend(['same', 'anti'])
        self.h_factory = h_factory
        self.w_factory = w_factory
        self.w_for_ne = w_for_ne

    def single_edge_type_update(self, nodes, edges, edge_type, normalize):
        w = (
            self.w_factory(self.two_particle_stream_dim, name=f'w_{edge_type}')
            if self.w_for_ne or edge_type != 'ne'
            else Identity()
        )
        we = w(edges[edge_type].single_array)
        h = self.h_factory(we.shape[-1], name=f'h_{edge_type}')
        hx = h(self.node_edge_mapping.sender_data_of(edge_type, nodes))
        if edges[edge_type].single_array.size == 0:
            # parameters acting on size zero arrays cause NaN gradients
            return jnp.zeros((hx.shape[0], self.two_particle_stream_dim))
        return edges[edge_type].update_from_single_array(we).convolve(hx, normalize)

    def __call__(self, nodes, edges) -> Sequence[jnp.ndarray]:
        updates = []
        for edge_type in self.edge_types:
            if edge_type == 'ee':
                ee = sum(
                    self.single_edge_type_update(nodes, edges, st, False)
                    for st in ['same', 'anti']
                )
                factor = self.n_up + self.n_down if self.normalize else 1.0
                updates.append(ee / factor)
            else:
                updates.append(
                    self.single_edge_type_update(
                        nodes, edges, edge_type, self.normalize
                    )
                )
        return updates

    @property
    def names(self):
        return [f'conv_{edge_type}' for edge_type in self.edge_types]
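
# Hedged sketch (helper not part of deepqmc, dense layout assumed): the
# convolution above computes, for each receiver electron i, the sum over its
# sender particles j of w(e_ij) * h(x_j), i.e. edge-filtered messages weighted
# by transformed sender embeddings. With dense per-pair edge features the same
# operation is a single einsum.
def _example_convolution(we: jnp.ndarray, hx: jnp.ndarray) -> jnp.ndarray:
    # we: [n_receivers, n_senders, dim] filtered edge embeddings w(e_ij)
    # hx: [n_senders, dim] transformed sender embeddings h(x_j)
    return jnp.einsum('rsd,sd->rd', we, hx)  # [n_receivers, dim]
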
class NodeAttentionUpdateFeature(UpdateFeature):
    r"""Create a single update feature by attending over the nodes.

    Returns the Psiformer update feature based on attention over the nodes.
    """

    def __init__(
        self, *args, num_heads, mlp_factory, attention_residual, mlp_residual
    ):
        super().__init__(*args)
        self.num_heads = num_heads
        self.attention_residual = attention_residual
        self.mlp_residual = mlp_residual
        self.mlp_factory = mlp_factory

    def __call__(self, nodes, edges):
        h = nodes.electrons
        heads_dim = h.shape[-1] // self.num_heads
        assert heads_dim * self.num_heads == h.shape[-1]
        attention_layer = hk.MultiHeadAttention(
            self.num_heads,
            heads_dim,
            w_init=hk.initializers.VarianceScaling(1, 'fan_in', 'normal'),
            with_bias=False,
        )
        mlp = self.mlp_factory(h.shape[-1], name='mlp')
        attended = attention_layer(h, h, h)
        if self.attention_residual:
            attended = self.attention_residual(h, attended)
        mlp_out = mlp(attended)
        if self.mlp_residual:
            mlp_out = self.mlp_residual(attended, mlp_out)
        return [mlp_out]
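
# Hedged usage sketch (helper and shapes are hypothetical, not deepqmc API):
# like every hk.Module in this file, the attention update can only be built
# inside an hk.transform-ed function, e.g.
#
#     model = hk.without_apply_rng(hk.transform(_example_self_attention))
#     params = model.init(rng, jnp.zeros((6, 64)))
#     out = model.apply(params, jnp.zeros((6, 64)))
#
# The helper mirrors the attention block above on a plain [n_electrons, dim]
# array, without the residual and MLP wrappers.
def _example_self_attention(h: jnp.ndarray, num_heads: int = 4) -> jnp.ndarray:
    heads_dim = h.shape[-1] // num_heads
    attention = hk.MultiHeadAttention(
        num_heads,
        heads_dim,
        w_init=hk.initializers.VarianceScaling(1, 'fan_in', 'normal'),
        with_bias=False,
    )
    return attention(h, h, h)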