Source code for deepqmc.wf.base

import logging
import operator

import haiku as hk
import jax
import jax.numpy as jnp

from deepqmc.parallel import replicate_on_devices

__all__ = ()

log = logging.getLogger(__name__)


def init_wf_params(rng, hamil, ansatz):
    rng_sample, rng_params = jax.random.split(rng)
    try:
        # QC
        R_shape = (len(hamil.mol.charges), 3)
    except AttributeError:
        # QHO
        R_shape = 0
    phys_conf = hamil.init_sample(rng_sample, jnp.zeros(R_shape), 1)[0]
    params = ansatz.init(rng_params, phys_conf)

    num_params = jax.tree_util.tree_reduce(
        operator.add, jax.tree_util.tree_map(lambda x: x.size, params)
    )
    log.info(f'Number of model parameters: {num_params}')
    params = replicate_on_devices(params)
    return params


[docs]class WaveFunction(hk.Module): r""" Base class for all trial wave functions. Shape: - Input, :math:`\mathbf r`, (float, :math:`(N,3)`, a.u.): particle coordinates - Output1, :math:`\ln|\psi(\mathbf r)|` (float): - Output2, :math:`\operatorname{sgn}\psi(\mathbf r)` (float): """ def __init__(self, hamil): super().__init__() self.mol = hamil.mol self.n_up, self.n_down = hamil.n_up, hamil.n_down @property def spin_slices(self): return slice(None, self.n_up), slice(self.n_up, None) def forward(self, rs): return NotImplemented