Source code for deepqmc.hkext

from collections.abc import Callable, Sequence
from typing import Optional, Union

import haiku as hk
import jax
import jax.numpy as jnp
from haiku.initializers import VarianceScaling
from jax import tree_util
from jax.nn import sigmoid, softplus


def ssp(x: jax.Array) -> jax.Array:
    r"""Compute the shifted softplus activation function.

    Computes the elementwise function
    :math:`\text{ssp}(x)=\log(1+\text{e}^x)+\log\frac{1}{2}`
    """
    return softplus(x) + jnp.log(0.5)
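
# Usage sketch (illustrative, not part of the original deepqmc source): ``ssp``
# is an elementwise activation and can be applied directly to arrays or passed
# as the ``activation`` argument of :class:`MLP` below. The helper name is ours.
def _example_ssp_usage():
    x = jnp.linspace(-2.0, 2.0, 5)
    # The shift makes ssp(0.0) == 0.0, since log(2) + log(1/2) = 0.
    return ssp(x)
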
class MLP(hk.Module):
    r"""Represent a multilayer perceptron.

    Args:
        out_dim (int): the output dimension.
        name (str): optional, the name of the network.
        hidden_layers (tuple): optional, either ('log', :math:`N_\text{layers}`),
            in which case the network will have :math:`N_\text{layers}` layers
            with logarithmically changing widths, or a tuple of ints specifying
            the width of each layer.
        bias (bool | str): optional, specifies which layers should have a bias
            term. Possible values are:

            - :data:`True`: all layers will have a bias term
            - :data:`False`: no layers will have a bias term
            - ``'not_last'``: all but the last layer will have a bias term
        last_linear (bool): optional, if :data:`True` the activation function
            is not applied to the output of the last layer.
        activation (~collections.abc.Callable): optional, the activation function.
        init (str | Callable): optional, specifies the initialization of the
            linear weights. Possible string values are:

            - ``'default'``: the default haiku initialization method is used.
            - ``'ferminet'``: the initialization method of the :class:`ferminet`
              package is used.
            - ``'deeperwin'``: the initialization method of the :class:`deeperwin`
              package is used.
    """

    def __init__(
        self,
        out_dim: int,
        name: Optional[str] = None,
        *,
        hidden_layers: Sequence[Union[int, str]],
        bias: bool,
        last_linear: bool,
        activation: Callable[[jax.Array], jax.Array],
        init: Union[str, Callable],
    ):
        assert bias in (True, False, 'not_last')
        super().__init__(name=name)
        self.activation = activation
        self.last_linear = last_linear
        self.bias = bias
        self.out_dim = out_dim
        if isinstance(init, str):
            self.w_init = {
                'deeperwin': VarianceScaling(1.0, 'fan_avg', 'uniform'),
                'default': VarianceScaling(1.0, 'fan_in', 'truncated_normal'),
                'ferminet': VarianceScaling(1.0, 'fan_in', 'normal'),
            }[init]
            self.b_init = {
                'deeperwin': lambda s, d: jnp.zeros(shape=s, dtype=d),
                'default': lambda s, d: jnp.zeros(shape=s, dtype=d),
                'ferminet': VarianceScaling(1.0, 'fan_out', 'normal'),
            }[init]
        else:
            self.w_init = init
            self.b_init = init
        self.hidden_layers = hidden_layers or []

    def __call__(self, inputs: jax.Array) -> jax.Array:
        if len(self.hidden_layers) == 2 and self.hidden_layers[0] == 'log':
            assert isinstance(self.hidden_layers[1], int)
            n_hidden = self.hidden_layers[1]
            qs = [k / n_hidden for k in range(1, n_hidden + 1)]
            dims = [
                round(inputs.shape[-1] ** (1 - q) * self.out_dim**q) for q in qs
            ]
        else:
            dims = [*self.hidden_layers, self.out_dim]
        n_layers = len(dims)
        layers = []
        for idx, dim in enumerate(dims):
            with_bias = self.bias is True or (
                self.bias == 'not_last' and idx < (n_layers - 1)
            )
            layers.append(
                hk.Linear(
                    output_size=dim,
                    with_bias=with_bias,
                    name='linear_%d' % idx,
                    w_init=self.w_init,
                    b_init=self.b_init,
                )
            )
        out = inputs
        for i, layer in enumerate(layers):
            out = layer(out)
            if i < (n_layers - 1) or not self.last_linear:
                out = self.activation(out)
        return out
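
# Usage sketch (illustrative, not part of the original deepqmc source): build a
# small MLP inside an ``hk.transform`` and apply it to dummy inputs. The helper
# name and all hyperparameter values below are arbitrary examples.
def _example_mlp_usage():
    def forward(x):
        return MLP(
            8,
            'example_mlp',
            hidden_layers=('log', 3),  # three layers with log-spaced widths
            bias='not_last',
            last_linear=True,
            activation=ssp,
            init='deeperwin',
        )(x)

    net = hk.without_apply_rng(hk.transform(forward))
    x = jnp.ones((4, 16))
    params = net.init(jax.random.PRNGKey(0), x)
    return net.apply(params, x)  # shape (4, 8)
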
class ResidualConnection:
    r"""Represent a residual connection between pytrees.

    The residual connection is only added if :data:`inp` and :data:`update`
    have the same shape.

    Args:
        normalize (bool): if :data:`True` the sum of :data:`inp` and
            :data:`update` is normalized with :data:`sqrt(2)`.
    """

    def __init__(self, *, normalize: bool):
        self.normalize = normalize

    def __call__(self, inp, update):
        def leaf_residual(x, y):
            if x.shape != y.shape:
                return y
            z = x + y
            return z / jnp.sqrt(2) if self.normalize else z

        return tree_util.tree_map(leaf_residual, inp, update)
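
# Usage sketch (illustrative, not part of the original deepqmc source): the
# residual is applied leaf-wise; leaves whose shapes differ are passed through
# unchanged, so for those leaves the update simply replaces the input.
def _example_residual_usage():
    residual = ResidualConnection(normalize=True)
    inp = {'a': jnp.ones((2, 4)), 'b': jnp.ones((2, 3))}
    upd = {'a': jnp.ones((2, 4)), 'b': jnp.ones((2, 5))}
    out = residual(inp, upd)
    # out['a'] == (1 + 1) / sqrt(2) elementwise; out['b'] is just upd['b'].
    return out
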
class SumPool:
    r"""Represent a global sum pooling operation.

    Args:
        out_dim (int): the output dimension.
        name (str): optional, the name of the network.
    """

    def __init__(self, out_dim, name=None):
        assert out_dim == 1

    def __call__(self, x):
        return tree_util.tree_map(lambda leaf: leaf.sum(axis=-1, keepdims=True), x)
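
# Usage sketch (illustrative, not part of the original deepqmc source):
# ``SumPool`` sums every pytree leaf over its last axis while keeping that axis
# with size one, which is why ``out_dim`` must be 1.
def _example_sum_pool_usage():
    pool = SumPool(out_dim=1)
    x = {'h': jnp.ones((4, 7))}
    return pool(x)  # {'h': array of shape (4, 1), each entry equal to 7.0}
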
class Identity:
    r"""Represent the identity operation."""

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, x):
        return x

class GLU(hk.Module):
    r"""Gated Linear Unit.

    Args:
        out_dim (int): the output dimension.
        name (str): optional, the name of the network.
        bias (bool): optional, whether to include a bias term.
        layer_norm_before (bool): optional, whether to apply layer normalization
            before the GLU operation.
        activation (~collections.abc.Callable): optional, the activation
            function, defaults to sigmoid.
        b_init (~collections.abc.Callable): optional, the initialization
            function for the bias term, defaults to zeros.
    """

    def __init__(
        self,
        out_dim: int,
        name: Optional[str] = None,
        *,
        bias: bool = True,
        layer_norm_before: bool = True,
        activation: Callable[[jax.Array], jax.Array] = sigmoid,
        b_init: Callable = jnp.zeros,
    ):
        super().__init__(name=name)
        self.activated_linear = hk.Linear(
            out_dim, name='W', with_bias=bias, b_init=b_init
        )
        self.linear = hk.Linear(out_dim, name='V', with_bias=bias, b_init=b_init)
        self.activation = activation
        self.layer_norm_before = layer_norm_before

    def __call__(self, x, y):
        if self.layer_norm_before:
            x = hk.LayerNorm(-1, False, False)(x)
            y = hk.LayerNorm(-1, False, False)(y)
        return self.activation(self.activated_linear(x)) * self.linear(y)
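
# Usage sketch (illustrative, not part of the original deepqmc source): the GLU
# gates one input stream with the other, ``activation(W(x)) * V(y)``, optionally
# layer-normalizing both inputs first. The helper name, shapes, and seed are
# arbitrary examples; the module must be built inside an ``hk.transform``.
def _example_glu_usage():
    def forward(x, y):
        return GLU(8, 'example_glu')(x, y)

    net = hk.without_apply_rng(hk.transform(forward))
    x = jnp.ones((4, 16))
    y = jnp.ones((4, 12))
    params = net.init(jax.random.PRNGKey(0), x, y)
    return net.apply(params, x, y)  # shape (4, 8)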