Source code for deepqmc.hkext

import haiku as hk
import jax.numpy as jnp
from haiku.initializers import VarianceScaling
from jax import tree_util
from jax.nn import softplus


def ssp(x):
    r"""
    Compute the shifted softplus activation function.

    Computes the elementwise function
    :math:`\text{ssp}(x)=\log(1+\text{e}^x)+\log\frac{1}{2}`
    """
    return softplus(x) + jnp.log(0.5)
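A minimal sanity check (not part of the module source), assuming the module is importable as deepqmc.hkext: the shift makes the activation vanish at the origin, since :math:`\log 2` cancels the added :math:`\log\frac{1}{2}`.

    # illustrative only: ssp(0) == softplus(0) + log(1/2) == log(2) - log(2) == 0
    import jax.numpy as jnp
    from deepqmc.hkext import ssp

    assert jnp.isclose(ssp(jnp.array(0.0)), 0.0)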
class MLP(hk.Module):
    r"""
    Represent a multilayer perceptron.

    Args:
        out_dim (int): the output dimension.
        name (str): optional, the name of the network.
        hidden_layers (tuple): optional, either ``('log',`` :math:`N_\text{layers}` ``)``,
            in which case the network will have :math:`N_\text{layers}` layers
            with logarithmically changing widths, or a tuple of ints specifying
            the width of each layer.
        bias (bool or str): optional, specifies which layers should have a bias
            term. Possible values are

            - :data:`True`: all layers will have a bias term
            - :data:`False`: no layers will have a bias term
            - ``'not_last'``: all but the last layer will have a bias term
        last_linear (bool): optional, if :data:`True` the activation function is
            not applied to the output of the last layer.
        activation (Callable): optional, the activation function.
        init (str or Callable): optional, specifies the initialization of the
            linear weights. Possible string values are:

            - ``'default'``: the default haiku initialization method is used.
            - ``'deeperwin'``: the initialization method of the :class:`deeperwin`
              package is used.
            - ``'ferminet'``: the initialization method of the :class:`ferminet`
              package is used.
    """

    def __init__(
        self,
        out_dim,
        name=None,
        *,
        hidden_layers,
        bias,
        last_linear,
        activation,
        init,
    ):
        assert bias in (True, False, 'not_last')
        super().__init__(name=name)
        self.activation = activation
        self.last_linear = last_linear
        self.bias = bias
        self.out_dim = out_dim
        if isinstance(init, str):
            self.w_init = {
                'deeperwin': VarianceScaling(1.0, 'fan_avg', 'uniform'),
                'default': VarianceScaling(1.0, 'fan_in', 'truncated_normal'),
                'ferminet': VarianceScaling(1.0, 'fan_in', 'normal'),
            }[init]
            self.b_init = {
                'deeperwin': lambda s, d: jnp.zeros(shape=s, dtype=d),
                'default': lambda s, d: jnp.zeros(shape=s, dtype=d),
                'ferminet': VarianceScaling(1.0, 'fan_out', 'normal'),
            }[init]
        else:
            self.w_init = init
            self.b_init = init
        self.hidden_layers = hidden_layers or []

    def __call__(self, inputs):
        if len(self.hidden_layers) == 2 and self.hidden_layers[0] == 'log':
            n_hidden = self.hidden_layers[1]
            qs = [k / n_hidden for k in range(1, n_hidden + 1)]
            dims = [round(inputs.shape[-1] ** (1 - q) * self.out_dim**q) for q in qs]
        else:
            dims = [*self.hidden_layers, self.out_dim]
        n_layers = len(dims)
        layers = []
        for idx, dim in enumerate(dims):
            with_bias = self.bias is True or (
                self.bias == 'not_last' and idx < (n_layers - 1)
            )
            layers.append(
                hk.Linear(
                    output_size=dim,
                    with_bias=with_bias,
                    name='linear_%d' % idx,
                    w_init=self.w_init,
                    b_init=self.b_init,
                )
            )
        out = inputs
        for i, layer in enumerate(layers):
            out = layer(out)
            if i < (n_layers - 1) or not self.last_linear:
                out = self.activation(out)
        return out
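A minimal usage sketch (not part of the module source), assuming the module is importable as deepqmc.hkext; the output dimension, layer specification, activation and initializer below are arbitrary example choices.

    import haiku as hk
    import jax
    import jax.numpy as jnp
    from deepqmc.hkext import MLP, ssp

    def forward(x):
        # three hidden layers with logarithmically interpolated widths
        net = MLP(
            16,
            hidden_layers=('log', 3),
            bias='not_last',
            last_linear=True,
            activation=ssp,
            init='deeperwin',
        )
        return net(x)

    model = hk.transform(forward)
    x = jnp.ones((4, 8))
    params = model.init(jax.random.PRNGKey(0), x)
    out = model.apply(params, None, x)  # out.shape == (4, 16)

With ``hidden_layers=('log', 3)`` the layer widths are interpolated between the input width and ``out_dim``, so the network here maps 8 -> 10 -> 13 -> 16 features.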
class ResidualConnection:
    r"""
    Represent a residual connection between pytrees.

    The residual connection is only added if :data:`inp` and :data:`update`
    have the same shape.

    Args:
        normalize (bool): if :data:`True` the sum of :data:`inp` and
            :data:`update` is normalized with :math:`\sqrt 2`.
    """

    def __init__(self, *, normalize):
        self.normalize = normalize

    def __call__(self, inp, update):
        def leaf_residual(x, y):
            if x.shape != y.shape:
                return y
            z = x + y
            return z / jnp.sqrt(2) if self.normalize else z

        return tree_util.tree_map(leaf_residual, inp, update)
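A minimal usage sketch (not part of the module source), assuming the module is importable as deepqmc.hkext: the residual is applied leafwise, and leaves whose shapes differ simply pass the update through.

    import jax.numpy as jnp
    from deepqmc.hkext import ResidualConnection

    residual = ResidualConnection(normalize=True)
    inp = {'a': jnp.ones((2, 4)), 'b': jnp.ones((2, 3))}
    update = {'a': jnp.ones((2, 4)), 'b': jnp.ones((2, 8))}  # 'b' shapes differ
    out = residual(inp, update)
    # out['a'] == (1 + 1) / sqrt(2) elementwise, out['b'] is update['b'] unchanged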
class SumPool:
    r"""Represent a global sum pooling operation."""

    def __init__(self, out_dim, name=None):
        assert out_dim == 1

    def __call__(self, x):
        return tree_util.tree_map(lambda leaf: leaf.sum(axis=-1, keepdims=True), x)
class Identity:
    r"""Represent the identity operation."""

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, x):
        return x
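A minimal usage sketch (not part of the module source), assuming the module is importable as deepqmc.hkext: :class:`SumPool` collapses the last axis of every leaf, while :class:`Identity` is a drop-in no-op alternative.

    import jax.numpy as jnp
    from deepqmc.hkext import Identity, SumPool

    pool = SumPool(out_dim=1)
    x = {'h': jnp.arange(6.0).reshape(2, 3)}
    pooled = pool(x)      # pooled['h'].shape == (2, 1)
    same = Identity()(x)  # returns x unchanged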