from functools import partial
from itertools import accumulate
import haiku as hk
import jax
import jax.numpy as jnp
from jax import tree_util
from ..hkext import MLP
from .graph import Graph, GraphNodes, GraphUpdate, MolecularGraphEdgeBuilder
from .utils import NodeEdgeMapping
class ElectronGNNLayer(hk.Module):
r"""
The message passing layer of :class:`ElectronGNN`.
Derived from :class:`~deepqmc.gnn.gnn.MessagePassingLayer`.
Args:
n_interactions (int): the number of message passing interactions.
ilayer (int): the index of this layer (0 <= ilayer < n_interactions).
n_nuc (int): the number of nuclei.
n_up (int): the number of spin up electrons.
n_down (int): the number of spin down electrons.
embedding_dim (int): the length of the electron embedding vectors.
edge_types (Tuple[str]): the types of edges to consider.
self_interaction (bool): whether to consider edges where the sender and
receiver electrons are the same.
node_data (Dict[str, Any]): a dictionary containing information about the
nodes of the graph.
two_particle_stream_dim (int): the feature dimension of the two particle
streams.
electron_residual: whether a residual connection is used when updating
the electron embeddings, either :data:`False`, or an instance of
:class:`~deepqmc.hkext.ResidualConnection`.
nucleus_residual: whether a residual connection is used when updating
the nucleus embeddings, either :data:`False`, or an instance of
:class:`~deepqmc.hkext.ResidualConnection`.
two_particle_residual: whether a residual connection is used when updating
the two particle embeddings, either :data:`False`, or an instance of
:class:`~deepqmc.hkext.ResidualConnection`.
deep_features: if :data:`False`, the edge features are not updated throughout
    the GNN layers; if :data:`'shared'`, in each layer a single MLP
    (:data:`u`) is used to update all edge types; if :data:`'separate'`, in
    each layer separate MLPs are used to update the different edge types.
update_features (list[~deepqmc.gnn.update_features.UpdateFeature]): a list of
partially initialized update feature classes to use when computing the
update features of the one particle embeddings. For more details see the
documentation of :class:`deepqmc.gnn.update_features`.
update_rule (str): how to combine the update features for the update of the
one particle embeddings.
Possible values:
- ``'concatenate'``: run concatenated features through MLP
- ``'featurewise'``: apply different MLP to each feature channel and sum
- ``'featurewise_shared'``: apply the same MLP across feature channels
- ``'sum'``: sum features before sending through an MLP
Note that :data:`'sum'` and :data:`'featurewise_shared'` require all
update features to have the same size.
subnet_factory (~collections.abc.Callable): optional, a function that constructs
the subnetworks of the GNN layer.
subnet_factory_by_lbl (dict): optional, a dictionary of functions that construct
    subnetworks of the GNN layer. If both this and :data:`subnet_factory` are
    specified, the values of :data:`subnet_factory_by_lbl` take
    precedence. If some keys are missing, the default value of
    :data:`subnet_factory` is used in their place. Possible keys are:
    (:data:`w`, :data:`h`, :data:`g` or :data:`u`).
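
A minimal sketch of the update rules on toy arrays (hypothetical shapes,
with plain functions standing in for the MLP subnetworks)::

    import jax.numpy as jnp

    feats = [jnp.ones((4, 8)), 2 * jnp.ones((4, 8))]  # two update features
    mlp = lambda x: 2 * x  # stand-in for an MLP subnetwork
    out = mlp(jnp.concatenate(feats, axis=-1))    # 'concatenate': (4, 16) -> MLP
    out = mlp(sum(feats))                         # 'sum': add, then one MLP
    out = jnp.sum(mlp(jnp.stack(feats)), axis=0)  # 'featurewise_shared'
    mlps = {'f0': mlp, 'f1': mlp}                 # 'featurewise': one MLP each
    out = sum(m(f) for f, m in zip(feats, mlps.values()))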
"""
def __init__(
self,
n_interactions,
ilayer,
n_nuc,
n_up,
n_down,
embedding_dim,
edge_types,
self_interaction,
node_data,
two_particle_stream_dim,
*,
electron_residual,
nucleus_residual,
two_particle_residual,
deep_features,
update_features,
update_rule,
subnet_factory=None,
subnet_factory_by_lbl=None,
):
super().__init__()
self.n_nuc, self.n_up, self.n_down = n_nuc, n_up, n_down
self.last_layer = ilayer == n_interactions - 1
self.edge_types = tuple(
typ for typ in edge_types if not self.last_layer or typ not in {'nn', 'en'}
)
self.mapping = NodeEdgeMapping(self.edge_types, node_data=node_data)
assert update_rule in [
'concatenate',
'featurewise',
'featurewise_shared',
'sum',
]
assert (
update_rule not in ['sum', 'featurewise_shared']
or embedding_dim == two_particle_stream_dim
)
assert deep_features in [False, 'shared', 'separate']
self.deep_features = deep_features
self.update_rule = update_rule
subnet_factory_by_lbl = subnet_factory_by_lbl or {}
for lbl in ['g', 'u']:
subnet_factory_by_lbl.setdefault(lbl, subnet_factory)
if deep_features:
self.u = (
subnet_factory_by_lbl['u'](two_particle_stream_dim, name='u')
if deep_features == 'shared'
else {
typ: subnet_factory_by_lbl['u'](
two_particle_stream_dim,
name=f'u{typ}',
)
for typ in self.edge_types
}
)
self.update_features = [
uf(self.n_up, self.n_down, two_particle_stream_dim, self.mapping)
for uf in update_features
]
self.g_factory = subnet_factory_by_lbl['g']
self.g = (
self.g_factory(
embedding_dim,
name='g',
)
if self.update_rule != 'featurewise'
else {
name: self.g_factory(
embedding_dim,
name=f'g_{name}',
)
for uf in self.update_features
for name in uf.names
}
)
self.electron_residual = electron_residual
self.nucleus_residual = nucleus_residual
self.two_particle_residual = two_particle_residual
self.self_interaction = self_interaction
def get_update_edges_fn(self):
def update_edges(edges):
if self.deep_features:
if self.deep_features == 'shared':
assert not isinstance(self.u, dict)
# combine features along leading dim, apply MLP and split
# into channels again to please kfac
keys, edge_objects = zip(*edges.items())
feats = [e.single_array for e in edge_objects]
split_idxs = list(accumulate(len(f) for f in feats))
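# the last split index equals the total length, so jnp.split returns an
# extra empty trailing array, which the zip below silently drops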
feats = jnp.split(self.u(jnp.concatenate(feats)), split_idxs)
edge_objects = [
e.update_from_single_array(f)
for e, f in zip(edge_objects, feats)
]
updated_edges = dict(zip(keys, edge_objects))
elif self.deep_features == 'separate':
updated_edges = {
typ: edge.update_from_single_array(
self.u[typ](edge.single_array)
)
for typ, edge in edges.items()
}
else:
raise ValueError(f'Unknown deep features: {self.deep_features}')
if self.two_particle_residual:
updated_edges = self.two_particle_residual(edges, updated_edges)
return updated_edges
else:
return edges
return update_edges
def get_aggregate_edges_for_nodes_fn(self):
def aggregate_edges_for_nodes(nodes, edges):
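# each update feature returns a list of node feature objects; summing
# with start=[] concatenates these lists into one flat list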
fs = sum(
(uf(nodes, edges) for uf in self.update_features),
start=[],
)
return GraphNodes(
[f.nuclei for f in fs if f.nuclei is not None],
[f.electrons for f in fs if f.electrons is not None],
)
return aggregate_edges_for_nodes
def get_update_nodes_fn(self):
def update_nodes(nodes, update_features: GraphNodes):
updated_electrons = self.apply_update_rule(
nodes.electrons,
self.g,
update_features.electrons,
self.electron_residual,
)
if nodes.nuclei is not None and update_features.nuclei:
g_nuc = (
self.g_factory(
nodes.nuclei.shape[-1],
name='g_nuc',
)
if self.update_rule != 'featurewise'
else {
name: self.g_factory(
nodes.nuclei.shape[-1],
name=f'g_nuc_{name}',
)
for uf in self.update_features
for name in uf.names
}
)
updated_nuclei = self.apply_update_rule(
nodes.nuclei,
g_nuc,
update_features.nuclei,
self.nucleus_residual,
)
else:
updated_nuclei = nodes.nuclei
return GraphNodes(updated_nuclei, updated_electrons)
return update_nodes
def apply_update_rule(self, nodes, update_network, update_features, residual):
if self.update_rule == 'concatenate':
updated = update_network(jnp.concatenate(update_features, axis=-1))
elif self.update_rule == 'featurewise':
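# dict entries preserve insertion order, so each MLP pairs up with the
# update feature it was created for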
updated = sum(
update_network[name](fi)
for fi, name in zip(update_features, update_network.keys())
)
elif self.update_rule == 'sum':
updated = update_network(sum(update_features))
elif self.update_rule == 'featurewise_shared':
updated = jnp.sum(update_network(jnp.stack(update_features)), axis=0)
else:
raise ValueError(f'Unknown update rule: {self.update_rule}')
if residual:
updated = residual(nodes, updated)
return updated
def __call__(self, graph):
r"""
Execute the message passing layer.
Args:
graph (:class:`Graph`)
Returns:
:class:`Graph`: updated graph
"""
update_graph = GraphUpdate(
update_nodes_fn=self.get_update_nodes_fn(),
update_edges_fn=None if self.last_layer else self.get_update_edges_fn(),
aggregate_edges_for_nodes_fn=self.get_aggregate_edges_for_nodes_fn(),
)
return update_graph(graph)
class ElectronGNN(hk.Module):
r"""
A neural network acting on graphs defined by electrons and nuclei.
Derived from :class:`~deepqmc.gnn.gnn.GraphNeuralNetwork`.
Args:
hamil (:class:`~deepqmc.hamil.MolecularHamiltonian`): the Hamiltonian of
the system on which the graph is defined.
embedding_dim (int): the length of the electron embedding vectors.
n_interactions (int): number of message passing interactions.
edge_features (dict): a :data:`dict` of functions for each edge
type, embedding the interparticle differences. Valid keys are:
- ``'ne'``: for nucleus-electron edges
- ``'nn'``: for nucleus-nucleus edges
- ``'same'``: for same spin electron-electron edges
- ``'anti'``: for opposite spin electron-electron edges
- ``'up'``: for edges going from spin up electrons to all electrons
- ``'down'``: for edges going from spin down electrons to all electrons
self_interaction (bool): whether to consider edges where the sender and
receiver electrons are the same.
two_particle_stream_dim (int): the feature dimension of the two particle
streams. Only active if :data:`deep_features` are used.
nuclei_embedding (~typing.Type[~deepqmc.gnn.electron_gnn.NucleiEmbedding]):
optional, the instance responsible for creating the initial nuclear
embeddings. Set to :data:`None` if nuclear embeddings are not needed.
electron_embedding (~typing.Type[~deepqmc.gnn.electron_gnn.ElectronEmbedding]):
the instance that creates the initial electron embeddings.
layer_factory (~typing.Type[~deepqmc.gnn.electron_gnn.ElectronGNNLayer]): a
callable that generates a layer of the GNN.
ghost_coords (jax.Array): optional, specifies the coordinates of one or more
ghost atoms, useful for breaking spatial symmetries of the nuclear geometry.
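
A sketch of how ghost atoms enter the network (hypothetical shapes)::

    import jax.numpy as jnp

    charges = jnp.array([1.0, 1.0])     # e.g. H2
    ghost_coords = jnp.zeros((1, 3))    # one ghost atom at the origin
    # ghosts are appended as nuclei with zero charge
    charges = jnp.concatenate([charges, jnp.zeros(len(ghost_coords))])
    R = jnp.zeros((16, 2, 3))           # batch of nuclear coordinates
    # and their coordinates are tiled over the batch dimension
    R = jnp.concatenate(
        [R, jnp.tile(ghost_coords[None], (len(R), 1, 1))], axis=-2
    )  # shape (16, 3, 3)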
"""
def __init__(
self,
hamil,
embedding_dim,
*,
n_interactions,
edge_features,
self_interaction,
two_particle_stream_dim,
nuclei_embedding,
electron_embedding,
layer_factory,
ghost_coords=None,
):
super().__init__()
n_nuc, n_up, n_down = hamil.n_nuc, hamil.n_up, hamil.n_down
n_atom_types = hamil.mol.n_atom_types
charges = hamil.mol.charges
self.ghost_coords = None
if ghost_coords is not None:
charges = jnp.concatenate([charges, jnp.zeros(len(ghost_coords))])
n_nuc += len(ghost_coords)
n_atom_types += 1
self.ghost_coords = jnp.asarray(ghost_coords)
self.n_nuc, self.n_up, self.n_down = n_nuc, n_up, n_down
self.embedding_dim = embedding_dim
self.node_data = {
'n_nodes': {'nuclei': n_nuc, 'electrons': n_up + n_down},
'n_node_types': {'electrons': 1 if n_up == n_down else 2},
'node_types': {
'electrons': jnp.array(n_up * [0] + n_down * [int(n_up != n_down)])
},
}
self.edge_types = tuple((edge_features or {}).keys())
self.layers = [
layer_factory(
n_interactions,
ilayer,
n_nuc,
n_up,
n_down,
embedding_dim,
self.edge_types,
self_interaction,
self.node_data,
two_particle_stream_dim,
)
for ilayer in range(n_interactions)
]
self.edge_features = edge_features
self.nuclei_embedding = (
nuclei_embedding(n_up, n_down, charges, n_atom_types)
if nuclei_embedding
else None
)
self.electron_embedding = electron_embedding(
n_nuc,
n_up,
n_down,
embedding_dim,
self.node_data['n_node_types']['electrons'],
self.node_data['node_types']['electrons'],
)
self.self_interaction = self_interaction
def node_factory(self, phys_conf):
nucleus_embedding = (
self.nuclei_embedding(phys_conf) if self.nuclei_embedding else None
)
electron_embedding = self.electron_embedding(phys_conf, nucleus_embedding)
return GraphNodes(nucleus_embedding, electron_embedding)
def edge_factory(self, phys_conf):
r"""Compute all the graph edges used in the GNN."""
edge_factory = MolecularGraphEdgeBuilder(
self.n_nuc,
self.n_up,
self.n_down,
self.edge_types,
self_interaction=self.self_interaction,
)
edges = edge_factory(phys_conf)
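# embed the raw interparticle difference features with the per-edge-type
# feature functions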
return {
typ: edges[typ].update_from_single_array(
self.edge_features[typ](edges[typ].single_array)
)
for typ in self.edge_types
}
def __call__(self, phys_conf):
r"""
Execute the graph neural network.
Args:
phys_conf (PhysicalConfiguration): the physical configuration
of the molecule.
Returns:
:class:`GraphNodes`: the final nuclear and electronic embeddings; the
electron embeddings have shape
(:math:`N_\text{elec}`, :data:`embedding_dim`).
"""
if self.ghost_coords is not None:
phys_conf = phys_conf._replace(
R=jnp.concatenate(
[
phys_conf.R,
jnp.tile(self.ghost_coords[None], (len(phys_conf.R), 1, 1)),
],
axis=-2,
)
)
graph_edges = self.edge_factory(phys_conf)
graph_nodes = self.node_factory(phys_conf)
graph = Graph(graph_nodes, graph_edges)
for layer in self.layers:
graph = layer(graph)
return graph.nodes
class NucleiEmbedding(hk.Module):
r"""Create initial embeddings for nuclei.
Args:
n_up (int): the number of spin up electrons.
n_down (int): the number of spin down electrons.
charges (jax.Array): the nuclear charges of the molecule.
n_atom_types (int): the number of different atom types in the molecule.
embedding_dim (int): the length of the output embedding vector.
atom_type_embedding (bool): if :data:`True`, initial embeddings are the same
for atoms of the same type (nuclear charge), otherwise they are different
for all nuclei.
subnet_type (str): the type of subnetwork to use for the embedding generation:
- ``'mlp'``: an MLP is used
- ``'embed'``: a :class:`haiku.Embed` block is used
edge_features (~deepqmc.gnn.edge_features.EdgeFeature): optional, the edge
features to use when constructing the initial nuclear embeddings.
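
A sketch of the atom-type indexing behind ``atom_type_embedding`` (a
hypothetical molecule)::

    import jax.numpy as jnp

    charges = jnp.array([8.0, 1.0, 1.0])  # e.g. water
    types = jnp.unique(charges, size=len(charges), return_inverse=True)[-1]
    # types == [1, 0, 0]: the two hydrogens share an atom type, so with
    # atom_type_embedding=True they get identical initial embeddings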
"""
def __init__(
self,
n_up,
n_down,
charges,
n_atom_types,
*,
embedding_dim,
atom_type_embedding,
subnet_type,
edge_features,
):
super().__init__()
assert subnet_type in ['mlp', 'embed']
self.edge_features = edge_features
if self.edge_features:
self.edge_factory = MolecularGraphEdgeBuilder(
len(charges),
n_up,
n_down,
['nn'],
self_interaction=True,
)
self.edge_mlp = MLP(
32,
'edge_mlp',
hidden_layers=(32,),
bias=True,
last_linear=True,
activation=jax.nn.silu,
init='ferminet',
)
self.embed_mlp = MLP(
embedding_dim,
'embed_mlp',
hidden_layers=(embedding_dim,),
bias=True,
last_linear=True,
activation=jax.nn.silu,
init='ferminet',
)
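# one-hot atom-type embedding, tiled across the nucleus-nucleus edge grid
# so it can be concatenated with the 'nn' edge features in __call__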
self.charge_embedding = jnp.tile(
jax.nn.one_hot(
jnp.unique(charges, size=len(charges), return_inverse=True)[-1],
len(charges),
)[:, None],
(1, len(charges), 1),
)
n_nuc_types = n_atom_types if atom_type_embedding else len(charges)
if subnet_type == 'mlp':
self.subnet = MLP(
embedding_dim,
hidden_layers=['log', 1],
bias=True,
last_linear=False,
activation=jnp.tanh,
init='deeperwin',
)
elif subnet_type == 'embed':
self.subnet = hk.Embed(n_nuc_types, embedding_dim)
self.input = (
jnp.arange(len(charges))
if not atom_type_embedding
else (
charges
if subnet_type == 'mlp'
else jnp.unique(charges, size=len(charges), return_inverse=True)[-1]
)
)
if subnet_type == 'mlp':
self.input = self.input[:, None]
def __call__(self, phys_conf):
if self.edge_features:
nn_features = self.edge_features(
self.edge_factory(phys_conf)['nn'].single_array
)
nn_features = jnp.concatenate([nn_features, self.charge_embedding], axis=-1)
nn_edges = self.edge_mlp(nn_features)
return self.embed_mlp(nn_edges.sum(axis=0))
else:
return self.subnet(self.input)
class ElectronEmbedding(hk.Module):
r"""Create initial embeddings for electrons.
Args:
n_nuc (int): the number of nuclei.
n_up (int): the number of spin up electrons.
n_down (int): the number of spin down electrons.
embedding_dim (int): the desired length of the embedding vectors.
n_elec_types (int): the number of electron types to differentiate.
Usual values are:
- ``1``: treat all electrons as indistinguishable. Note that electrons
with different spins can still become distinguishable during the later
embedding update steps of the GNN.
- ``2``: treat spin up and spin down electrons as distinguishable already
in the initial embeddings.
elec_types (jax.Array): an integer array with length equal to the number of
electrons, with entries between ``0`` and ``n_elec_types``. Specifies the
type for each electron.
positional_embeddings (dict): optional, if not ``None``, a ``dict`` with edge
types as keys, and edge features as values. Specifies the edge types and
edge features to use when constructing the positional initial electron
embeddings.
use_spin (bool): only relevant if ``positional_embeddings`` is not ``None``;
    if ``True``, the spin of the given electron is concatenated after the
    positional embedding features.
project_to_embedding_dim (bool): only relevant if ``positional_embeddings``
    is not ``None``; if ``True``, a linear layer projects the initial
    embeddings to length ``embedding_dim``.
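
A sketch of the spin feature appended when ``use_spin`` is set
(hypothetical electron counts)::

    import jax.numpy as jnp

    n_up, n_down = 2, 1
    spins = jnp.concatenate([jnp.ones(n_up), -jnp.ones(n_down)])[:, None]
    # spins == [[1.], [1.], [-1.]], concatenated to the positional features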
"""
def __init__(
self,
n_nuc,
n_up,
n_down,
embedding_dim,
n_elec_types,
elec_types,
*,
positional_embeddings,
use_spin,
project_to_embedding_dim,
):
super().__init__()
self.n_nuc = n_nuc
self.n_up = n_up
self.n_down = n_down
self.embedding_dim = embedding_dim
self.n_elec_types = n_elec_types
self.elec_types = elec_types
self.positional_embeddings = positional_embeddings
self.use_spin = use_spin
self.project_to_embedding_dim = project_to_embedding_dim
def __call__(self, phys_conf, nucleus_embedding):
if self.positional_embeddings:
edge_factory = MolecularGraphEdgeBuilder(
self.n_nuc,
self.n_up,
self.n_down,
self.positional_embeddings.keys(),
self_interaction=False,
)
feats = tree_util.tree_map(
lambda f, e: f(e.single_array)
.swapaxes(0, 1)
.reshape(self.n_up + self.n_down, -1),
self.positional_embeddings,
edge_factory(phys_conf),
)
# concatenate the per-edge-type features along the feature axis
x = jnp.concatenate(tree_util.tree_leaves(feats), axis=1)
if self.use_spin:
spins = jnp.concatenate([jnp.ones(self.n_up), -jnp.ones(self.n_down)])[
:, None
]
x = jnp.concatenate([x, spins], axis=1)
if self.project_to_embedding_dim:
x = hk.Linear(self.embedding_dim, with_bias=False)(x)
else:
x = hk.Embed(
    self.n_elec_types, self.embedding_dim, name='ElectronicEmbedding'
)(self.elec_types)
return x
class PermutationInvariantEmbedding(hk.Module):
r"""Electron embeddings that are invariant to exchanges of identical nuclei."""
def __init__(
self,
n_nuc,
n_up,
n_down,
embedding_dim,
n_elec_types,
elec_types,
charges,
*,
edge_dim,
edge_features,
nuclear_charge_dependence,
use_spin,
):
assert nuclear_charge_dependence in {'concatenate', 'elementwise-product'}
super().__init__()
self.n_up = n_up
self.n_down = n_down
self.embedding_dim = embedding_dim
self.edge_factory = MolecularGraphEdgeBuilder(
n_nuc,
n_up,
n_down,
['ne'],
self_interaction=False,
)
self.edge_features = edge_features
self.nuclear_charge_dependence = nuclear_charge_dependence
self.charge_embedding = jax.nn.one_hot(
jnp.unique(charges, size=len(charges), return_inverse=True)[-1],
len(charges),
)
self.use_spin = use_spin
if nuclear_charge_dependence == 'elementwise-product':
self.charge_linear = hk.Linear(edge_dim, name='edge_linear', with_bias=True)
self.edge_linear = hk.Linear(edge_dim, with_bias=True)
else:
self.charge_embedding = jnp.tile(
self.charge_embedding[:, None], (1, n_up + n_down, 1)
)
self.edge_mlp = MLP(
edge_dim,
'edge_mlp',
hidden_layers=(edge_dim,),
bias=True,
last_linear=True,
activation=jax.nn.silu,
init='ferminet',
)
self.embed_mlp = MLP(
embedding_dim,
'embed_mlp',
hidden_layers=(embedding_dim,),
bias=True,
last_linear=True,
activation=jax.nn.silu,
init='ferminet',
)
def __call__(self, phys_conf, nucleus_embedding):
ne_features = self.edge_features(
self.edge_factory(phys_conf)['ne'].single_array
)
if self.nuclear_charge_dependence == 'elementwise-product':
ne_edges = (
jax.nn.sigmoid(self.edge_linear(ne_features))
* self.charge_linear(self.charge_embedding)[..., None, :]
)
else:
nucleus_embedding = (
self.charge_embedding
if nucleus_embedding is None
else jnp.tile(
nucleus_embedding[:, None, :], (1, self.n_up + self.n_down, 1)
)
)
ne_features = jnp.concatenate([ne_features, nucleus_embedding], axis=-1)
ne_edges = self.edge_mlp(ne_features)
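# summing over the nucleus axis makes the embedding invariant to
# permutations of identical nuclei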
electron_features = ne_edges.sum(axis=0)
if self.use_spin:
spins = jnp.concatenate([jnp.ones(self.n_up), -jnp.ones(self.n_down)])[
:, None
]
electron_features = jnp.concatenate([electron_features, spins], axis=1)
return self.embed_mlp(electron_features)