# Source code for deepqmc.wf.nn_wave_function

from typing import Literal

import haiku as hk
import jax
import jax.numpy as jnp

from ..physics import pairwise_diffs, pairwise_self_distance
from ..types import Psi
from ..utils import flatten, triu_flat

# Star-import surface of this module: only the wave-function class is listed,
# so ``BackflowOp`` and ``eval_log_slater`` are not exported by
# ``from ... import *`` (they remain importable by explicit name).
__all__ = ['NeuralNetworkWaveFunction']


class BackflowOp(hk.Module):
    """Apply multiplicative and/or additive backflow factors to orbitals.

    Args:
        mult_act: activation applied to the multiplicative factors. A falsy
            value selects the default ``1 + 2 * tanh(x / 4)``.
        add_act: activation applied to the additive factors. A falsy value
            selects the default ``0.1 * tanh(x / 4)``.
        with_envelope (bool): if true, the additive term is scaled by a norm
            of the incoming orbitals; otherwise the scale is ``1``.
    """

    def __init__(self, mult_act=None, add_act=None, with_envelope=True):
        super().__init__()
        self.with_envelope = with_envelope
        self.mult_act = (
            mult_act if mult_act else (lambda x: 1 + 2 * jnp.tanh(x / 4))
        )
        self.add_act = add_act if add_act else (lambda x: 0.1 * jnp.tanh(x / 4))

    def __call__(self, xs, fs_mult, fs_add, dists_nuc):
        """Return the orbitals ``xs`` after applying the backflow factors.

        Args:
            xs: orbital values; must have at least three axes, because the
                envelope reduces over axes ``-1`` and ``-3``.
                (presumably determinant x electron x orbital -- TODO confirm)
            fs_mult: multiplicative backflow factors, or ``None`` to skip the
                multiplicative update.
            fs_add: additive backflow factors, or ``None`` to skip the additive
                update.
            dists_nuc: distances whose last axis is reduced with ``min``.
                (presumably electron-nucleus distances -- TODO confirm)

        Returns:
            The transformed orbitals.
        """
        # The envelope is taken from the *incoming* orbitals, i.e. before the
        # multiplicative factor below modifies them.
        envel = (
            jnp.sqrt((xs**2).sum(axis=(-1, -3), keepdims=True))
            if self.with_envelope
            else 1
        )
        if fs_mult is not None:
            xs = xs * self.mult_act(fs_mult)
        if fs_add is None:
            return xs
        # Smallest distance along the last axis, rescaled so that the cutoff
        # region is ``scaled < 1``.
        scaled = dists_nuc.min(axis=-1) / 0.5
        # Polynomial switch: 0 at scaled == 0, 1 at scaled == 1, and held at 1
        # beyond, so the additive term vanishes at zero distance.
        cutoff = jnp.where(
            scaled < 1,
            scaled**2 * (6 - 8 * scaled + 3 * scaled**2),
            jnp.ones_like(scaled),
        )
        assert isinstance(cutoff, jax.Array)
        return xs + cutoff[None, :, None] * envel * self.add_act(fs_add)


def eval_log_slater(xs):
    """Return ``(sign, log|det|)`` of the square matrices in the last two axes.

    A matrix with a zero-length last axis is treated as having determinant one:
    the sign is ``1`` and the log-determinant is ``0`` for every batch entry.
    Otherwise the result of ``jnp.linalg.slogdet`` is returned unchanged.
    """
    batch_shape = xs.shape[:-2]
    if xs.shape[-1] != 0:
        return jnp.linalg.slogdet(xs)
    return jnp.ones(batch_shape), jnp.zeros(batch_shape)


class NeuralNetworkWaveFunction(hk.Module):
    r"""
    Implements the neural network wave function.

    Configuration files to obtain the PauliNet [HermannNC20]_, FermiNet
    [PfauPRR20]_, DeepErwin [Gerard22]_ and PsiFormer [Glehn22]_ architectures
    are provided. For a detailed description of the implemented architecture
    see [Schaetzle23]_.

    Args:
        hamil (~deepqmc.hamil.MolecularHamiltonian): the Hamiltonian of the
            system.
        omni_factory (~collections.abc.Callable): creates the omni net.
        envelope (~deepqmc.wf.env.ExponentialEnvelopes): the orbital envelopes.
        backflow_op (~collections.abc.Callable): specifies how the backflow is
            applied to the orbitals.
        n_determinants (int): specifies the number of determinants
        full_determinant (bool): if :data:`False`, the determinants are
            factorized into spin-up and spin-down parts.
        cusp_electrons (~collections.abc.Callable): constructor of the
            electronic cusp module.
        cusp_nuclei (~collections.abc.Callable): constructor of the nuclear
            cusp module.
        backflow_transform (str): describes the backflow transformation.
            Possible values:

            - ``'mult'``: the backflow is a multiplicative factor
            - ``'add'``: the backflow is an additive term
            - ``'both'``: the backflow consists of a multiplicative factor and
              an additive term
        conf_coeff (~collections.abc.Callable): returns a function that
            combines the determinants to obtain the WF value
    """

    def __init__(
        self,
        hamil,
        *,
        omni_factory,
        envelope,
        backflow_op,
        n_determinants,
        full_determinant,
        cusp_electrons,
        cusp_nuclei,
        backflow_transform: Literal['mult', 'add', 'both'],
        conf_coeff,
    ):
        super().__init__()
        self.mol = hamil.mol
        self.n_up, self.n_down = hamil.n_up, hamil.n_down
        self.charges = hamil.mol.charges
        n_up, n_down = self.n_up, self.n_down
        self.n_det = n_determinants
        self.full_determinant = full_determinant
        self.envelope = envelope(hamil, n_determinants)
        self.conf_coeff = conf_coeff(1, name='conf_coeff')
        # The cusp modules are optional: a falsy constructor disables them.
        self.cusp_electrons = cusp_electrons() if cusp_electrons else None
        self.cusp_nuclei = cusp_nuclei(hamil.mol.charges) if cusp_nuclei else None
        # Positional arguments forwarded to ``omni_factory``: two orbital
        # counts (both equal to the total electron number for full
        # determinants, otherwise the per-spin counts), the number of
        # determinants, and the number of backflow channels (two only when
        # both a multiplicative and an additive factor are requested).
        backflow_spec = [
            *((n_up + n_down, n_up + n_down) if full_determinant else (n_up, n_down)),
            n_determinants,
            2 if backflow_transform == 'both' else 1,
        ]
        self.backflow_transform = backflow_transform
        self.backflow_op = backflow_op() if backflow_op else None
        self.omni = omni_factory(hamil, *backflow_spec) if omni_factory else None

    @property
    def spin_slices(self):
        """Slices selecting the first ``n_up`` and the remaining electrons."""
        return slice(None, self.n_up), slice(self.n_up, None)

    def _backflow_op(self, xs, fs, dists_nuc):
        """Route ``fs`` to ``self.backflow_op`` according to the transform.

        Args:
            xs: orbitals of one spin block.
            fs: backflow factors for that block; the leading axis is split in
                two for ``'both'`` and squeezed away otherwise.
            dists_nuc: distances forwarded unchanged to ``self.backflow_op``.

        Returns:
            Whatever ``self.backflow_op`` returns for the routed factors.
        """
        assert self.backflow_op is not None
        if self.backflow_transform == 'mult':
            fs_mult, fs_add = fs, None
        elif self.backflow_transform == 'add':
            fs_mult, fs_add = None, fs
        elif self.backflow_transform == 'both':
            fs_mult, fs_add = jnp.split(fs, 2, axis=0)
        else:
            # Any other value passes no factors at all, so ``backflow_op`` is
            # called with ``None`` for both.
            fs_mult, fs_add = None, None
        # ``squeeze(axis=0)`` requires the leading axis to have length one.
        fs_add = fs_add.squeeze(axis=0) if fs_add is not None else fs_add
        fs_mult = fs_mult.squeeze(axis=0) if fs_mult is not None else fs_mult
        return self.backflow_op(xs, fs_mult, fs_add, dists_nuc)

    def __call__(self, phys_conf, return_mos=False):
        """Evaluate the wave function for one physical configuration.

        Args:
            phys_conf: configuration object providing ``r`` and ``R``.
                (presumably electron and nuclear coordinates -- TODO confirm)
            return_mos (bool): if true, return the spin-up and spin-down
                orbital blocks instead of the wave function value.

        Returns:
            ``(orb_up, orb_down)`` if ``return_mos`` is true, otherwise
            ``Psi(sign_psi, log_psi)``.
        """
        diffs_nuc = pairwise_diffs(phys_conf.r, phys_conf.R)
        # NOTE(review): presumably the last component of ``pairwise_diffs``
        # holds the squared distance -- confirm against ``..physics``.
        dists_nuc = jnp.sqrt(diffs_nuc[..., -1])
        dists_elec = pairwise_self_distance(phys_conf.r, full=True)
        jastrow, fs, nuc_params = (
            self.omni(phys_conf) if self.omni else (None, None, None)
        )
        orb = self.envelope(phys_conf, nuc_params)
        # Full determinants reuse the same orbital block for both spins;
        # otherwise the last axis is split at ``n_up``.
        orb_up, orb_down = (
            (orb, orb)
            if self.full_determinant
            else jnp.split(orb, [self.n_up], axis=-1)
        )
        orb_up, orb_down = orb_up[:, : self.n_up], orb_down[:, self.n_up :]
        if fs is not None:
            orb_up = self._backflow_op(orb_up, fs[0], dists_nuc[: self.n_up])
            orb_down = self._backflow_op(orb_down, fs[1], dists_nuc[self.n_up :])
        if return_mos:
            return orb_up, orb_down
        if self.full_determinant:
            sign, xs = eval_log_slater(jnp.concatenate([orb_up, orb_down], axis=-2))
        else:
            sign_up, det_up = eval_log_slater(orb_up)
            sign_down, det_down = eval_log_slater(orb_down)
            # Product of the two spin determinants: signs multiply and
            # log-magnitudes add.
            sign, xs = sign_up * sign_down, det_up + det_down
        # the exp-normalize trick, to avoid over/underflow of the exponential
        xs_shift = xs.max()
        # replace -inf shifts, to avoid running into nans (see sloglindet)
        xs_shift = jnp.where(~jnp.isinf(xs_shift), xs_shift, jnp.zeros_like(xs_shift))
        xs = sign * jnp.exp(xs - xs_shift)
        psi = self.conf_coeff(xs).squeeze()
        # Add the shift back so that ``log_psi`` refers to the unshifted value.
        log_psi = jnp.log(jnp.abs(psi)) + xs_shift
        sign_psi = jax.lax.stop_gradient(jnp.sign(psi))
        if self.cusp_electrons:
            # Distances within each spin block and across the two blocks are
            # passed to the electronic cusp separately.
            same_dists = jnp.concatenate(
                [triu_flat(dists_elec[idxs, idxs]) for idxs in self.spin_slices],
                axis=-1,
            )
            anti_dists = flatten(dists_elec[: self.n_up, self.n_up :])
            log_psi += self.cusp_electrons(same_dists, anti_dists)
        if self.cusp_nuclei:
            log_psi += self.cusp_nuclei(dists_nuc)
        if jastrow is not None:
            log_psi = log_psi + jastrow
        return Psi(sign_psi, log_psi)